vrf_rs/
ecvrf.rs

1use std::ops::Mul;
2
3use elliptic_curve::{
4    generic_array::{typenum::Unsigned, GenericArray},
5    group::cofactor::CofactorGroup,
6    ops::MulByGenerator,
7    sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint},
8    Curve,
9    CurveArithmetic,
10    Field,
11    ProjectivePoint,
12    Scalar,
13    ScalarPrimitive,
14};
15use sha2::{
16    digest::{crypto_common::BlockSizeUser, FixedOutput, FixedOutputReset},
17    Digest,
18};
19
20use crate::{
21    error::{Result, VrfError},
22    VrfStruct,
23};
24
25impl<C, D> VrfStruct<C, D>
26where
27    C: Curve,
28    C: CurveArithmetic,
29    C::FieldBytesSize: ModulusSize,
30{
31    /// Curve cofactor, i.e., number of points on EC divided by prime order of group G.
32    pub const fn cofactor(&self) -> Scalar<C> {
33        // TODO: Change me! Wrong assumption that all curves have cofactor 1
34        <C as CurveArithmetic>::Scalar::ONE
35    }
36
37    /// Length, in octets, of a point on E encoded as an octet string.
38    const fn pt_len(&self) -> usize {
39        <C as Curve>::FieldBytesSize::USIZE
40    }
41
42    /// Length, in octets, of a challenge value used by the VRF.
43    /// Note: in the typical case, cLen is qLen/2 or close to it.
44    const fn c_len(&self) -> usize {
45        self.q_len() / 2
46    }
47
48    /// Length, in octets, of the prime order of group G (subgroup of EC of large prime order),
49    /// i.e., the smallest integer such that `2^(8qLen) > q`.
50    const fn q_len(&self) -> usize {
51        // TODO: Change me! It should be:
52        // const Q_LEN: usize = <Self::Curve as Curve>::ORDER.bits() / 8;
53        <C as Curve>::FieldBytesSize::USIZE
54    }
55}
56
57impl<C, D> VrfStruct<C, D>
58where
59    C: CurveArithmetic,
60    C::FieldBytesSize: ModulusSize,
61    C::AffinePoint: FromEncodedPoint<C>,
62    C::ProjectivePoint: ToEncodedPoint<C> + CofactorGroup,
63    D: Digest + BlockSizeUser + FixedOutput<OutputSize = C::FieldBytesSize> + FixedOutputReset,
64{
65    /// Generates a VRF proof from a secret key and message.
66    /// Spec: `ECVRF_prove` function (section 5.1).
67    ///
68    /// # Arguments
69    ///
70    /// * `x` - A slice representing the secret key in octets.
71    /// * `alpha` - A slice representing the message in octets.
72    ///
73    /// # Returns
74    ///
75    /// * If successful, a vector of octets representing the proof of the VRF.
76    pub fn prove(&self, secret_key: &[u8], alpha: &[u8]) -> Result<Vec<u8>> {
77        // Step 1: derive public key from secret key as `Y = x * B`
78        let secret_key_scalar = self.scalar_from_bytes(secret_key)?;
79        let public_key_point = C::ProjectivePoint::mul_by_generator(&secret_key_scalar);
80
81        let public_key_bytes: Vec<u8> = public_key_point.to_encoded_point(true).as_bytes().to_vec();
82
83        // Step 2: Encode to curve (using TAI)
84        let h_point = ProjectivePoint::<C>::from(self.encode_to_curve_tai(&public_key_bytes, alpha)?);
85
86        // Step 3: point to string (or bytes)
87        let h_point_bytes = h_point.to_encoded_point(true).as_bytes().to_vec();
88
89        // Step 4: Gamma = x * H
90        let gamma_point = h_point.mul(secret_key_scalar);
91        let gamma_point_bytes = gamma_point.to_encoded_point(true).as_bytes().to_vec();
92
93        // Step 5: nonce (k generation)
94        let k_scalar = self.scalar_from_bytes(&self.generate_nonce(secret_key, &h_point_bytes))?;
95
96        // Step 6: c = ECVRF_challenge_generation (Y, H, Gamma, U, V)
97        // U = k*B = k*Generator
98        let u_point = C::ProjectivePoint::mul_by_generator(&k_scalar);
99        let u_point_bytes = u_point.to_encoded_point(true).as_bytes().to_vec();
100        // V = k*H
101        let v_point = h_point * k_scalar;
102        let v_point_bytes = v_point.to_encoded_point(true).as_bytes().to_vec();
103        // Challenge generation (returns hash output truncated by `cLen`)
104        let c_scalar_bytes = self.challenge_generation(
105            &[
106                &public_key_bytes,
107                &h_point_bytes,
108                &gamma_point_bytes,
109                &u_point_bytes,
110                &v_point_bytes,
111            ],
112            self.c_len(),
113        )?;
114        let mut c_padded_bytes: Vec<u8> = vec![0; C::FieldBytesSize::USIZE - self.c_len()];
115        c_padded_bytes.extend_from_slice(&c_scalar_bytes);
116        let c_scalar = self.scalar_from_bytes(&c_padded_bytes)?;
117
118        // Step 7: s = (k + c*x) mod q
119        let s_scalar = k_scalar + c_scalar * secret_key_scalar;
120        let s_scalar_bytes = Into::<ScalarPrimitive<C>>::into(s_scalar).to_bytes();
121
122        // Step 8: encode (gamma, c, s)
123        let proof = [&gamma_point_bytes[..], &c_scalar_bytes, &s_scalar_bytes].concat();
124
125        Ok(proof)
126    }
127
128    /// Verifies the provided VRF proof and computes the VRF hash output.
129    /// Spec: `ECVRF_verify` function (section 5.2).
130    ///
131    /// # Arguments
132    ///
133    /// * `y`     - A slice representing the public key in octets.
134    /// * `pi`    - A slice of octets representing the VRF proof.
135    /// * `alpha` - A slice containing the input data, to be hashed.
136    ///
137    /// # Returns
138    ///
139    /// * If successful, a vector of octets with the VRF hash output.
140    pub fn verify(&self, public_key: &[u8], pi: &[u8], alpha: &[u8]) -> Result<GenericArray<u8, C::FieldBytesSize>> {
141        // Step 1-2: Y = string_to_point(PK_string)
142        let public_key_point = C::ProjectivePoint::from(self.point_from_bytes(public_key)?);
143
144        // Step 3: If validate_key, run ECVRF_validate_key(Y) (Section 5.4.5)
145        // TODO: Check step 3 again
146        if public_key_point.is_small_order().into() {
147            return Err(VrfError::VerifyInvalidKey);
148        }
149
150        // Step 4-6: D = ECVRF_decode_proof(pi_string)
151        let (gamma_point_bytes, c_scalar_bytes, s_scalar_bytes) = self.decode_proof(pi)?;
152        let gamma_point = C::ProjectivePoint::from(self.point_from_bytes(&gamma_point_bytes)?);
153        let c_scalar = self.scalar_from_bytes(&c_scalar_bytes)?;
154        let s_scalar = self.scalar_from_bytes(&s_scalar_bytes)?;
155
156        // Step 7: H = ECVRF_encode_to_curve(encode_to_curve_salt, alpha_string)
157        let h_point = ProjectivePoint::<C>::from(self.encode_to_curve_tai(public_key, alpha)?);
158        let h_point_bytes = h_point.to_encoded_point(true).as_bytes().to_vec();
159
160        // Step 8: U = s*B - c*Y
161        let u_point = C::ProjectivePoint::mul_by_generator(&s_scalar) - public_key_point * c_scalar;
162        let u_point_bytes = u_point.to_encoded_point(true).as_bytes().to_vec();
163
164        // Step 9: V = s*H - c*Gamma
165        let v_point = h_point * s_scalar - gamma_point * c_scalar;
166        let v_point_bytes = v_point.to_encoded_point(true).as_bytes().to_vec();
167
168        // Step 10: c' = ECVRF_challenge_generation(Y, H, Gamma, U, V)
169        let derived_c_bytes = self.challenge_generation(
170            &[
171                public_key,
172                &h_point_bytes,
173                &gamma_point_bytes,
174                &u_point_bytes,
175                &v_point_bytes,
176            ],
177            self.c_len(),
178        )?;
179        let mut padded_derived_c_bytes: Vec<u8> = vec![0; C::FieldBytesSize::USIZE - self.c_len()];
180        padded_derived_c_bytes.extend_from_slice(&derived_c_bytes);
181
182        // Step 11: Check if c and c' are equal
183        if padded_derived_c_bytes != c_scalar_bytes {
184            return Err(VrfError::InvalidProof);
185        }
186
187        // If valid VRF proof, ECVRF_proof_to_hash(pi_string)
188        self.gamma_to_hash(&gamma_point)
189    }
190
191    /// Function to compute VRF hash output for a given proof.
192    /// Spec: `ECVRF_proof_to_hash` function (steps 4-to 7).
193    ///
194    /// # Arguments
195    ///
196    /// * `proof`  - A vector of octets representing the proof of the VRF
197    ///
198    /// # Returns
199    ///
200    /// * A vector of octets with the VRF hash output.
201    pub fn proof_to_hash(&self, pi: &[u8]) -> Result<GenericArray<u8, C::FieldBytesSize>> {
202        let gamma_point_bytes = self.decode_proof(pi)?.0;
203        let gamma_point = C::ProjectivePoint::from(self.point_from_bytes(&gamma_point_bytes)?);
204
205        self.gamma_to_hash(&gamma_point)
206    }
207
208    /// Function to compute VRF hash output for a given gamma point (part of the VRF proof).
209    /// Spec: `ECVRF_proof_to_hash` function (steps 4-to 7).
210    ///
211    /// # Arguments
212    ///
213    /// * `gamma`  - An EC point representing the VRF gamma.
214    ///
215    /// # Returns
216    ///
217    /// * A vector of octets with the VRF hash output.
218    pub(crate) fn gamma_to_hash(&self, gamma: &C::ProjectivePoint) -> Result<GenericArray<u8, C::FieldBytesSize>> {
219        // Step 4: proof_to_hash_domain_separator_front = 0x03
220        const PROOF_TO_HASH_DOMAIN_SEPARATOR_FRONT: u8 = 0x03;
221
222        // Step 5: proof_to_hash_domain_separator_back = 0x00
223        const PROOF_TO_HASH_DOMAIN_SEPARATOR_BACK: u8 = 0x00;
224
225        // Step 6: Compute beta
226        // beta_string = Hash(suite_string || proof_to_hash_domain_separator_front ||
227        //                    point_to_string(cofactor * Gamma) || proof_to_hash_domain_separator_back)
228        let point: ProjectivePoint<C> = gamma.mul(self.cofactor());
229        let point_bytes = point.to_encoded_point(true).as_bytes().to_vec();
230
231        Ok(D::digest(
232            [
233                &[self.suite_id],
234                &[PROOF_TO_HASH_DOMAIN_SEPARATOR_FRONT],
235                &point_bytes[..],
236                &[PROOF_TO_HASH_DOMAIN_SEPARATOR_BACK],
237            ]
238            .concat(),
239        ))
240    }
241
242    /// Decodes a VRF proof by extracting the gamma EC point, and parameters `c` and `s` as bytes.
243    /// Spec: `ECVRF_decode_proof` function in section 5.4.4.
244    ///
245    /// # Arguments
246    ///
247    /// * `pi`  - A slice of octets representing the VRF proof
248    ///
249    /// # Returns
250    ///
251    /// * A tuple containing `gamma` point, and parameters `c` and `s`.
252    pub(crate) fn decode_proof(&self, pi: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
253        // Expected size of proof: len(pi) = len(gamma) + len(c) + len(s)
254        // len(s) = 2 * len(c), so len(pi) = len(gamma) + 3 * len(c)
255        let gamma_oct = self.pt_len() + 1;
256        if pi.len() != gamma_oct + self.c_len() * 3 {
257            return Err(VrfError::InvalidPiLength);
258        }
259
260        // Gamma point
261        let gamma = pi[0..gamma_oct].to_vec();
262
263        // C scalar (needs to be padded with leading zeroes)
264        let mut c_scalar: Vec<u8> = vec![0; <C as Curve>::FieldBytesSize::USIZE - self.c_len()];
265        c_scalar.extend_from_slice(&pi[gamma_oct..gamma_oct + self.c_len()]);
266
267        // S scalar
268        let s_scalar = pi[gamma_oct + self.c_len()..].to_vec();
269
270        Ok((gamma, c_scalar, s_scalar))
271    }
272}