Skip to main content

webgpu_groth16/prover/
gpu_key.rs

//! Persistent GPU proving key — pre-uploaded base point buffers.
//!
//! Holds the G1 bases (interleaved with their GLV endomorphism images when
//! the curve supports GLV) and the G2 bases on the GPU, already converted to
//! Montgomery form. The key is reused across proofs, so the bases are
//! uploaded and converted once instead of once per proof.
6
7use super::prepared_key::{PreparedProvingKey, interleave_glv_bases};
8use crate::gpu::GpuContext;
9use crate::gpu::curve::GpuCurve;
10
11/// Pre-uploaded GPU base point buffers for a specific circuit.
12///
13/// Created once per circuit via [`prepare_gpu_proving_key`].
14pub struct GpuProvingKey<G: GpuCurve> {
15    pub(crate) h_len: usize,
16    pub(crate) alpha_g1: G::G1Affine,
17    pub(crate) beta_g1: G::G1Affine,
18    pub(crate) beta_g2: G::G2Affine,
19    pub(crate) delta_g1: G::G1Affine,
20    pub(crate) delta_g2: G::G2Affine,
21    pub(crate) a_bases_buf: wgpu::Buffer,
22    pub(crate) b_g1_bases_buf: wgpu::Buffer,
23    pub(crate) l_bases_buf: wgpu::Buffer,
24    pub(crate) h_bases_buf: wgpu::Buffer,
25    pub(crate) b_g2_bases_buf: wgpu::Buffer,
26}
27
28/// Upload proving key bases to the GPU and convert to Montgomery form (one-time
29/// cost).
30pub fn prepare_gpu_proving_key<G: GpuCurve>(
31    ppk: &PreparedProvingKey<G>,
32    gpu: &GpuContext<G>,
33) -> GpuProvingKey<G> {
34    let a_combined;
35    let b_g1_combined;
36    let l_combined;
37    let h_combined;
38    if G::HAS_G1_GLV {
39        // Interleave G1 bases: [P₀, φ(P₀), P₁, φ(P₁), ...]
40        a_combined = interleave_glv_bases(
41            &ppk.a_bytes,
42            ppk.a_phi_bytes
43                .as_deref()
44                .expect("HAS_G1_GLV requires a_phi_bytes"),
45            G::G1_GPU_BYTES,
46        );
47        b_g1_combined = interleave_glv_bases(
48            &ppk.b_g1_bytes,
49            ppk.b_g1_phi_bytes
50                .as_deref()
51                .expect("HAS_G1_GLV requires b_g1_phi_bytes"),
52            G::G1_GPU_BYTES,
53        );
54        l_combined = interleave_glv_bases(
55            &ppk.l_bytes,
56            ppk.l_phi_bytes
57                .as_deref()
58                .expect("HAS_G1_GLV requires l_phi_bytes"),
59            G::G1_GPU_BYTES,
60        );
61        h_combined = interleave_glv_bases(
62            &ppk.h_bytes,
63            ppk.h_phi_bytes
64                .as_deref()
65                .expect("HAS_G1_GLV requires h_phi_bytes"),
66            G::G1_GPU_BYTES,
67        );
68    } else {
69        a_combined = ppk.a_bytes.clone();
70        b_g1_combined = ppk.b_g1_bytes.clone();
71        l_combined = ppk.l_bytes.clone();
72        h_combined = ppk.h_bytes.clone();
73    }
74
75    // Upload to GPU
76    let a_bases_buf = gpu.create_storage_buffer("gpk_a_bases", &a_combined);
77    let b_g1_bases_buf =
78        gpu.create_storage_buffer("gpk_b1_bases", &b_g1_combined);
79    let l_bases_buf = gpu.create_storage_buffer("gpk_l_bases", &l_combined);
80    let h_bases_buf = gpu.create_storage_buffer("gpk_h_bases", &h_combined);
81    let b_g2_bases_buf =
82        gpu.create_storage_buffer("gpk_b2_bases", &ppk.b_g2_bytes);
83
84    // Convert all bases to Montgomery form on GPU (one-time)
85    gpu.convert_to_montgomery(&a_bases_buf, false);
86    gpu.convert_to_montgomery(&b_g1_bases_buf, false);
87    gpu.convert_to_montgomery(&l_bases_buf, false);
88    gpu.convert_to_montgomery(&h_bases_buf, false);
89    gpu.convert_to_montgomery(&b_g2_bases_buf, true);
90
91    // Wait for all conversions to complete
92    #[cfg(not(target_family = "wasm"))]
93    let _ = gpu.device.poll(wgpu::PollType::wait_indefinitely());
94
95    GpuProvingKey {
96        h_len: ppk.h_len,
97        alpha_g1: ppk.alpha_g1,
98        beta_g1: ppk.beta_g1,
99        beta_g2: ppk.beta_g2,
100        delta_g1: ppk.delta_g1,
101        delta_g2: ppk.delta_g2,
102        a_bases_buf,
103        b_g1_bases_buf,
104        l_bases_buf,
105        h_bases_buf,
106        b_g2_bases_buf,
107    }
108}