sp1_gpu_prover/
components.rs1use std::{collections::BTreeMap, sync::Arc};
2
3use sp1_gpu_air::{air_block::BlockAir, codegen_cuda_eval, SymbolicProverFolder};
4use sp1_gpu_challenger::{DuplexChallenger, MultiField32Challenger};
5use sp1_gpu_cudart::{PinnedBuffer, TaskScope};
6
7use slop_basefold::BasefoldVerifier;
8use slop_bn254::Bn254Fr;
9use slop_challenger::IopCtx;
10use slop_futures::queue::WorkerQueue;
11use sp1_core_machine::riscv::RiscvAir;
12use sp1_gpu_basefold::FriCudaProver;
13use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2Bn254CudaProver, Poseidon2SP1Field16CudaProver};
14use sp1_gpu_shard_prover::{CudaShardProver, CudaShardProverComponents};
15use sp1_gpu_tracegen::CudaTracegenAir;
16use sp1_hypercube::{air::MachineAir, prover::ZerocheckAir, SP1InnerPcs, SP1OuterPcs, SP1SC};
17use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
18use sp1_prover::{CompressAir, ReadyWrapProverBuilder, SP1ProverComponents, WrapAir};
19
20pub struct SP1CudaProverComponents;
21
22impl SP1ProverComponents for SP1CudaProverComponents {
23 type CoreProver = CudaShardProver<SP1GlobalContext, CudaProverCoreComponents>;
24 type RecursionProver = CudaShardProver<SP1GlobalContext, CudaProverRecursionComponents>;
25 type WrapProver = CudaShardProver<SP1OuterGlobalContext, CudaProverWrapComponents>;
26 type WrapProverBuilder = ReadyWrapProverBuilder<Self>;
27}
28
29pub struct CudaProverCoreComponents;
31
32impl CudaShardProverComponents<SP1GlobalContext> for CudaProverCoreComponents {
33 type P = Poseidon2SP1Field16CudaProver;
34 type Air = RiscvAir<SP1Field>;
35 type C = SP1InnerPcs;
36 type DeviceChallenger = DuplexChallenger<SP1Field, TaskScope>;
37}
38
39pub struct CudaProverRecursionComponents;
41
42impl CudaShardProverComponents<SP1GlobalContext> for CudaProverRecursionComponents {
43 type P = Poseidon2SP1Field16CudaProver;
44 type Air = CompressAir<<SP1GlobalContext as IopCtx>::F>;
45 type C = SP1InnerPcs;
46 type DeviceChallenger = DuplexChallenger<SP1Field, TaskScope>;
47}
48
49pub struct CudaProverWrapComponents;
51
52impl CudaShardProverComponents<SP1OuterGlobalContext> for CudaProverWrapComponents {
53 type P = Poseidon2Bn254CudaProver;
54 type Air = WrapAir<<SP1OuterGlobalContext as IopCtx>::F>;
55 type C = SP1OuterPcs;
56 type DeviceChallenger = MultiField32Challenger<SP1Field, Bn254Fr, TaskScope>;
57}
58
59pub async fn new_cuda_prover<GC, PC>(
60 verifier: &sp1_hypercube::MachineVerifier<GC, SP1SC<GC, PC::Air>>,
61 max_trace_size: usize,
62 num_workers: usize,
63 recompute_first_layer: bool,
64 scope: TaskScope,
65) -> CudaShardProver<GC, PC>
66where
67 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField>,
68 PC: CudaShardProverComponents<GC>,
69 PC::P: CudaTcsProver<GC>,
70 PC::Air: CudaTracegenAir<GC::F>
71 + for<'a> BlockAir<SymbolicProverFolder<'a>>
72 + ZerocheckAir<GC::F, GC::EF>
73 + std::fmt::Debug,
74{
75 let machine = verifier.machine().clone();
76
77 let mut cache = BTreeMap::new();
78 for chip in machine.chips() {
79 let result = codegen_cuda_eval(chip.air.as_ref());
80 cache.insert(chip.air.name().to_string(), result);
81 }
82
83 let log_stacking_height = verifier.log_stacking_height();
84 let max_log_row_count = verifier.max_log_row_count();
85
86 let basefold_verifier = BasefoldVerifier::<GC>::new(*verifier.fri_config(), 2);
88
89 let tcs_prover = PC::P::new(&scope);
90 let basefold_prover = FriCudaProver::<GC, PC::P, GC::F>::new(
91 tcs_prover,
92 basefold_verifier.fri_config,
93 log_stacking_height,
94 );
95
96 let mut all_interactions = BTreeMap::new();
97
98 for chip in machine.chips().iter() {
99 let host_interactions = sp1_gpu_logup_gkr::Interactions::new(chip.sends(), chip.receives());
100 let device_interactions = host_interactions.copy_to_device(&scope).unwrap();
101 all_interactions.insert(chip.name().to_string(), Arc::new(device_interactions));
102 }
103
104 let mut trace_buffers = Vec::with_capacity(num_workers);
105 for _ in 0..num_workers {
106 let pinned_buffer = PinnedBuffer::<GC::F>::with_capacity(max_trace_size);
107 trace_buffers.push(pinned_buffer);
108 }
109
110 let trace_buffers = Arc::new(WorkerQueue::new(trace_buffers));
111 CudaShardProver::<GC, PC>::new(
112 trace_buffers,
113 max_log_row_count as u32,
114 basefold_prover,
115 machine,
116 max_trace_size,
117 scope,
118 all_interactions,
119 cache,
120 recompute_first_layer,
121 recompute_first_layer,
122 )
123}