Skip to main content

webgpu_groth16/prover/
prepared_key.rs

1//! Pre-serialized proving key for GPU dispatch.
2//!
3//! Converts proving key bases to GPU-friendly byte representations once,
4//! amortizing serialization cost across multiple proofs. When available,
5//! also stores GLV endomorphism bases φ(P) for G1 sets.
6
7use crate::bellman;
8use crate::gpu::curve::GpuCurve;
9
10/// Pre-serialized proving key bases for GPU. Avoids re-serialization per proof.
11///
12/// Includes GLV endomorphism bases φ(P) for G1 sets, pre-computed once to
13/// amortize the endomorphism cost across proofs.
14pub struct PreparedProvingKey<G: GpuCurve> {
15    pub h_len: usize,
16    pub alpha_g1: G::G1Affine,
17    pub beta_g1: G::G1Affine,
18    pub beta_g2: G::G2Affine,
19    pub delta_g1: G::G1Affine,
20    pub delta_g2: G::G2Affine,
21    pub a_bytes: Vec<u8>,
22    pub a_phi_bytes: Option<Vec<u8>>,
23    pub b_g1_bytes: Vec<u8>,
24    pub b_g1_phi_bytes: Option<Vec<u8>>,
25    pub l_bytes: Vec<u8>,
26    pub l_phi_bytes: Option<Vec<u8>>,
27    pub h_bytes: Vec<u8>,
28    pub h_phi_bytes: Option<Vec<u8>>,
29    pub b_g2_bytes: Vec<u8>,
30}
31
32pub(crate) fn serialize_g1_bases<G: GpuCurve>(
33    bases: &[G::G1Affine],
34) -> Vec<u8> {
35    let mut bytes = Vec::with_capacity(bases.len() * G::G1_GPU_BYTES);
36    for base in bases {
37        bytes.extend_from_slice(&G::serialize_g1(base));
38    }
39    bytes
40}
41
42pub(crate) fn serialize_g1_phi_bases<G: GpuCurve>(
43    bases: &[G::G1Affine],
44) -> Vec<u8> {
45    debug_assert!(G::HAS_G1_GLV);
46    let mut bytes = Vec::with_capacity(bases.len() * G::G1_GPU_BYTES);
47    for base in bases {
48        let base_bytes = G::serialize_g1(base);
49        let phi = G::g1_endomorphism_base_bytes(&base_bytes)
50            .expect("HAS_G1_GLV requires g1_endomorphism_base_bytes");
51        bytes.extend_from_slice(&phi);
52    }
53    bytes
54}
55
56pub(crate) fn serialize_g2_bases<G: GpuCurve>(
57    bases: &[G::G2Affine],
58) -> Vec<u8> {
59    let mut bytes = Vec::with_capacity(bases.len() * G::G2_GPU_BYTES);
60    for base in bases {
61        bytes.extend_from_slice(&G::serialize_g2(base));
62    }
63    bytes
64}
65
/// Interleaves base bytes and φ bytes into a `[P₀, φ(P₀), P₁, φ(P₁), ...]` layout.
///
/// `bases_bytes` and `phi_bytes` are flat buffers of `point_size`-byte
/// encodings. The i-th chunk of `phi_bytes` is written directly after the
/// i-th chunk of `bases_bytes`, so the result is twice as long as
/// `bases_bytes`.
///
/// # Panics
///
/// Panics if `point_size` is zero, if `bases_bytes.len()` is not a multiple
/// of `point_size`, or if the two buffers differ in length. These checks are
/// active in release builds too. Without them, a mismatch would silently drop
/// trailing bytes and hand a misaligned buffer to the GPU.
pub(crate) fn interleave_glv_bases(
    bases_bytes: &[u8],
    phi_bytes: &[u8],
    point_size: usize,
) -> Vec<u8> {
    assert!(point_size > 0, "interleave_glv_bases: point_size must be non-zero");
    assert_eq!(
        bases_bytes.len() % point_size,
        0,
        "interleave_glv_bases: base buffer length {} is not a multiple of point_size {}",
        bases_bytes.len(),
        point_size
    );
    assert_eq!(
        phi_bytes.len(),
        bases_bytes.len(),
        "interleave_glv_bases: phi buffer length must equal base buffer length"
    );

    let mut combined = Vec::with_capacity(2 * bases_bytes.len());
    // chunks_exact yields only full points; the asserts above guarantee there
    // is no remainder and that both sides pair up one-to-one.
    for (base, phi) in bases_bytes
        .chunks_exact(point_size)
        .zip(phi_bytes.chunks_exact(point_size))
    {
        combined.extend_from_slice(base);
        combined.extend_from_slice(phi);
    }
    combined
}

84pub fn prepare_proving_key<E, G>(
85    pk: &bellman::groth16::Parameters<E>,
86) -> PreparedProvingKey<G>
87where
88    E: pairing::MultiMillerLoop,
89    G: GpuCurve<
90            Engine = E,
91            Scalar = E::Fr,
92            G1 = E::G1,
93            G2 = E::G2,
94            G1Affine = E::G1Affine,
95            G2Affine = E::G2Affine,
96        >,
97{
98    let a_phi = if G::HAS_G1_GLV {
99        Some(serialize_g1_phi_bases::<G>(&pk.a))
100    } else {
101        None
102    };
103    let b1_phi = if G::HAS_G1_GLV {
104        Some(serialize_g1_phi_bases::<G>(&pk.b_g1))
105    } else {
106        None
107    };
108    let l_phi = if G::HAS_G1_GLV {
109        Some(serialize_g1_phi_bases::<G>(&pk.l))
110    } else {
111        None
112    };
113    let h_phi = if G::HAS_G1_GLV {
114        Some(serialize_g1_phi_bases::<G>(&pk.h))
115    } else {
116        None
117    };
118
119    PreparedProvingKey {
120        h_len: pk.h.len(),
121        alpha_g1: pk.vk.alpha_g1,
122        beta_g1: pk.vk.beta_g1,
123        beta_g2: pk.vk.beta_g2,
124        delta_g1: pk.vk.delta_g1,
125        delta_g2: pk.vk.delta_g2,
126        a_bytes: serialize_g1_bases::<G>(&pk.a),
127        a_phi_bytes: a_phi,
128        b_g1_bytes: serialize_g1_bases::<G>(&pk.b_g1),
129        b_g1_phi_bytes: b1_phi,
130        l_bytes: serialize_g1_bases::<G>(&pk.l),
131        l_phi_bytes: l_phi,
132        h_bytes: serialize_g1_bases::<G>(&pk.h),
133        h_phi_bytes: h_phi,
134        b_g2_bytes: serialize_g2_bases::<G>(&pk.b_g2),
135    }
136}