Skip to main content

sp1_gpu_prover/
components.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use sp1_gpu_air::{air_block::BlockAir, codegen_cuda_eval, SymbolicProverFolder};
4use sp1_gpu_challenger::{DuplexChallenger, MultiField32Challenger};
5use sp1_gpu_cudart::{PinnedBuffer, TaskScope};
6
7use slop_basefold::BasefoldVerifier;
8use slop_bn254::Bn254Fr;
9use slop_challenger::IopCtx;
10use slop_futures::queue::WorkerQueue;
11use sp1_core_machine::riscv::RiscvAir;
12use sp1_gpu_basefold::FriCudaProver;
13use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2Bn254CudaProver, Poseidon2SP1Field16CudaProver};
14use sp1_gpu_shard_prover::{CudaShardProver, CudaShardProverComponents};
15use sp1_gpu_tracegen::CudaTracegenAir;
16use sp1_hypercube::{air::MachineAir, prover::ZerocheckAir, SP1InnerPcs, SP1OuterPcs, SP1SC};
17use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
18use sp1_prover::{CompressAir, ReadyWrapProverBuilder, SP1ProverComponents, WrapAir};
19
20pub struct SP1CudaProverComponents;
21
22impl SP1ProverComponents for SP1CudaProverComponents {
23    type CoreProver = CudaShardProver<SP1GlobalContext, CudaProverCoreComponents>;
24    type RecursionProver = CudaShardProver<SP1GlobalContext, CudaProverRecursionComponents>;
25    type WrapProver = CudaShardProver<SP1OuterGlobalContext, CudaProverWrapComponents>;
26    type WrapProverBuilder = ReadyWrapProverBuilder<Self>;
27}
28
29/// Core prover components for the CUDA prover.
30pub struct CudaProverCoreComponents;
31
32impl CudaShardProverComponents<SP1GlobalContext> for CudaProverCoreComponents {
33    type P = Poseidon2SP1Field16CudaProver;
34    type Air = RiscvAir<SP1Field>;
35    type C = SP1InnerPcs;
36    type DeviceChallenger = DuplexChallenger<SP1Field, TaskScope>;
37}
38
39/// Recursion prover components for the CUDA prover.
40pub struct CudaProverRecursionComponents;
41
42impl CudaShardProverComponents<SP1GlobalContext> for CudaProverRecursionComponents {
43    type P = Poseidon2SP1Field16CudaProver;
44    type Air = CompressAir<<SP1GlobalContext as IopCtx>::F>;
45    type C = SP1InnerPcs;
46    type DeviceChallenger = DuplexChallenger<SP1Field, TaskScope>;
47}
48
49/// Wrap prover components for the CUDA prover.
50pub struct CudaProverWrapComponents;
51
52impl CudaShardProverComponents<SP1OuterGlobalContext> for CudaProverWrapComponents {
53    type P = Poseidon2Bn254CudaProver;
54    type Air = WrapAir<<SP1OuterGlobalContext as IopCtx>::F>;
55    type C = SP1OuterPcs;
56    type DeviceChallenger = MultiField32Challenger<SP1Field, Bn254Fr, TaskScope>;
57}
58
59pub async fn new_cuda_prover<GC, PC>(
60    verifier: &sp1_hypercube::MachineVerifier<GC, SP1SC<GC, PC::Air>>,
61    max_trace_size: usize,
62    num_workers: usize,
63    recompute_first_layer: bool,
64    scope: TaskScope,
65) -> CudaShardProver<GC, PC>
66where
67    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField>,
68    PC: CudaShardProverComponents<GC>,
69    PC::P: CudaTcsProver<GC>,
70    PC::Air: CudaTracegenAir<GC::F>
71        + for<'a> BlockAir<SymbolicProverFolder<'a>>
72        + ZerocheckAir<GC::F, GC::EF>
73        + std::fmt::Debug,
74{
75    let machine = verifier.machine().clone();
76
77    let mut cache = BTreeMap::new();
78    for chip in machine.chips() {
79        let result = codegen_cuda_eval(chip.air.as_ref());
80        cache.insert(chip.air.name().to_string(), result);
81    }
82
83    let log_stacking_height = verifier.log_stacking_height();
84    let max_log_row_count = verifier.max_log_row_count();
85
86    // Create the basefold prover from the verifier's PCS config
87    let basefold_verifier = BasefoldVerifier::<GC>::new(*verifier.fri_config(), 2);
88
89    let tcs_prover = PC::P::new(&scope);
90    let basefold_prover = FriCudaProver::<GC, PC::P, GC::F>::new(
91        tcs_prover,
92        basefold_verifier.fri_config,
93        log_stacking_height,
94    );
95
96    let mut all_interactions = BTreeMap::new();
97
98    for chip in machine.chips().iter() {
99        let host_interactions = sp1_gpu_logup_gkr::Interactions::new(chip.sends(), chip.receives());
100        let device_interactions = host_interactions.copy_to_device(&scope).unwrap();
101        all_interactions.insert(chip.name().to_string(), Arc::new(device_interactions));
102    }
103
104    let mut trace_buffers = Vec::with_capacity(num_workers);
105    for _ in 0..num_workers {
106        let pinned_buffer = PinnedBuffer::<GC::F>::with_capacity(max_trace_size);
107        trace_buffers.push(pinned_buffer);
108    }
109
110    let trace_buffers = Arc::new(WorkerQueue::new(trace_buffers));
111    CudaShardProver::<GC, PC>::new(
112        trace_buffers,
113        max_log_row_count as u32,
114        basefold_prover,
115        machine,
116        max_trace_size,
117        scope,
118        all_interactions,
119        cache,
120        recompute_first_layer,
121        recompute_first_layer,
122    )
123}