webgpu_groth16/prover/
gpu_key.rs1use super::prepared_key::{PreparedProvingKey, interleave_glv_bases};
8use crate::gpu::GpuContext;
9use crate::gpu::curve::GpuCurve;
10
11pub struct GpuProvingKey<G: GpuCurve> {
15 pub(crate) h_len: usize,
16 pub(crate) alpha_g1: G::G1Affine,
17 pub(crate) beta_g1: G::G1Affine,
18 pub(crate) beta_g2: G::G2Affine,
19 pub(crate) delta_g1: G::G1Affine,
20 pub(crate) delta_g2: G::G2Affine,
21 pub(crate) a_bases_buf: wgpu::Buffer,
22 pub(crate) b_g1_bases_buf: wgpu::Buffer,
23 pub(crate) l_bases_buf: wgpu::Buffer,
24 pub(crate) h_bases_buf: wgpu::Buffer,
25 pub(crate) b_g2_bases_buf: wgpu::Buffer,
26}
27
28pub fn prepare_gpu_proving_key<G: GpuCurve>(
31 ppk: &PreparedProvingKey<G>,
32 gpu: &GpuContext<G>,
33) -> GpuProvingKey<G> {
34 let a_combined;
35 let b_g1_combined;
36 let l_combined;
37 let h_combined;
38 if G::HAS_G1_GLV {
39 a_combined = interleave_glv_bases(
41 &ppk.a_bytes,
42 ppk.a_phi_bytes
43 .as_deref()
44 .expect("HAS_G1_GLV requires a_phi_bytes"),
45 G::G1_GPU_BYTES,
46 );
47 b_g1_combined = interleave_glv_bases(
48 &ppk.b_g1_bytes,
49 ppk.b_g1_phi_bytes
50 .as_deref()
51 .expect("HAS_G1_GLV requires b_g1_phi_bytes"),
52 G::G1_GPU_BYTES,
53 );
54 l_combined = interleave_glv_bases(
55 &ppk.l_bytes,
56 ppk.l_phi_bytes
57 .as_deref()
58 .expect("HAS_G1_GLV requires l_phi_bytes"),
59 G::G1_GPU_BYTES,
60 );
61 h_combined = interleave_glv_bases(
62 &ppk.h_bytes,
63 ppk.h_phi_bytes
64 .as_deref()
65 .expect("HAS_G1_GLV requires h_phi_bytes"),
66 G::G1_GPU_BYTES,
67 );
68 } else {
69 a_combined = ppk.a_bytes.clone();
70 b_g1_combined = ppk.b_g1_bytes.clone();
71 l_combined = ppk.l_bytes.clone();
72 h_combined = ppk.h_bytes.clone();
73 }
74
75 let a_bases_buf = gpu.create_storage_buffer("gpk_a_bases", &a_combined);
77 let b_g1_bases_buf =
78 gpu.create_storage_buffer("gpk_b1_bases", &b_g1_combined);
79 let l_bases_buf = gpu.create_storage_buffer("gpk_l_bases", &l_combined);
80 let h_bases_buf = gpu.create_storage_buffer("gpk_h_bases", &h_combined);
81 let b_g2_bases_buf =
82 gpu.create_storage_buffer("gpk_b2_bases", &ppk.b_g2_bytes);
83
84 gpu.convert_to_montgomery(&a_bases_buf, false);
86 gpu.convert_to_montgomery(&b_g1_bases_buf, false);
87 gpu.convert_to_montgomery(&l_bases_buf, false);
88 gpu.convert_to_montgomery(&h_bases_buf, false);
89 gpu.convert_to_montgomery(&b_g2_bases_buf, true);
90
91 #[cfg(not(target_family = "wasm"))]
93 let _ = gpu.device.poll(wgpu::PollType::wait_indefinitely());
94
95 GpuProvingKey {
96 h_len: ppk.h_len,
97 alpha_g1: ppk.alpha_g1,
98 beta_g1: ppk.beta_g1,
99 beta_g2: ppk.beta_g2,
100 delta_g1: ppk.delta_g1,
101 delta_g2: ppk.delta_g2,
102 a_bases_buf,
103 b_g1_bases_buf,
104 l_bases_buf,
105 h_bases_buf,
106 b_g2_bases_buf,
107 }
108}