1use std::ops::Mul;
2
3use elliptic_curve::{
4 generic_array::{typenum::Unsigned, GenericArray},
5 group::cofactor::CofactorGroup,
6 ops::MulByGenerator,
7 sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint},
8 Curve,
9 CurveArithmetic,
10 Field,
11 ProjectivePoint,
12 Scalar,
13 ScalarPrimitive,
14};
15use sha2::{
16 digest::{crypto_common::BlockSizeUser, FixedOutput, FixedOutputReset},
17 Digest,
18};
19
20use crate::{
21 error::{Result, VrfError},
22 VrfStruct,
23};
24
25impl<C, D> VrfStruct<C, D>
26where
27 C: Curve,
28 C: CurveArithmetic,
29 C::FieldBytesSize: ModulusSize,
30{
31 pub const fn cofactor(&self) -> Scalar<C> {
33 <C as CurveArithmetic>::Scalar::ONE
35 }
36
37 const fn pt_len(&self) -> usize {
39 <C as Curve>::FieldBytesSize::USIZE
40 }
41
42 const fn c_len(&self) -> usize {
45 self.q_len() / 2
46 }
47
48 const fn q_len(&self) -> usize {
51 <C as Curve>::FieldBytesSize::USIZE
54 }
55}
56
57impl<C, D> VrfStruct<C, D>
58where
59 C: CurveArithmetic,
60 C::FieldBytesSize: ModulusSize,
61 C::AffinePoint: FromEncodedPoint<C>,
62 C::ProjectivePoint: ToEncodedPoint<C> + CofactorGroup,
63 D: Digest + BlockSizeUser + FixedOutput<OutputSize = C::FieldBytesSize> + FixedOutputReset,
64{
65 pub fn prove(&self, secret_key: &[u8], alpha: &[u8]) -> Result<Vec<u8>> {
77 let secret_key_scalar = self.scalar_from_bytes(secret_key)?;
79 let public_key_point = C::ProjectivePoint::mul_by_generator(&secret_key_scalar);
80
81 let public_key_bytes: Vec<u8> = public_key_point.to_encoded_point(true).as_bytes().to_vec();
82
83 let h_point = ProjectivePoint::<C>::from(self.encode_to_curve_tai(&public_key_bytes, alpha)?);
85
86 let h_point_bytes = h_point.to_encoded_point(true).as_bytes().to_vec();
88
89 let gamma_point = h_point.mul(secret_key_scalar);
91 let gamma_point_bytes = gamma_point.to_encoded_point(true).as_bytes().to_vec();
92
93 let k_scalar = self.scalar_from_bytes(&self.generate_nonce(secret_key, &h_point_bytes))?;
95
96 let u_point = C::ProjectivePoint::mul_by_generator(&k_scalar);
99 let u_point_bytes = u_point.to_encoded_point(true).as_bytes().to_vec();
100 let v_point = h_point * k_scalar;
102 let v_point_bytes = v_point.to_encoded_point(true).as_bytes().to_vec();
103 let c_scalar_bytes = self.challenge_generation(
105 &[
106 &public_key_bytes,
107 &h_point_bytes,
108 &gamma_point_bytes,
109 &u_point_bytes,
110 &v_point_bytes,
111 ],
112 self.c_len(),
113 )?;
114 let mut c_padded_bytes: Vec<u8> = vec![0; C::FieldBytesSize::USIZE - self.c_len()];
115 c_padded_bytes.extend_from_slice(&c_scalar_bytes);
116 let c_scalar = self.scalar_from_bytes(&c_padded_bytes)?;
117
118 let s_scalar = k_scalar + c_scalar * secret_key_scalar;
120 let s_scalar_bytes = Into::<ScalarPrimitive<C>>::into(s_scalar).to_bytes();
121
122 let proof = [&gamma_point_bytes[..], &c_scalar_bytes, &s_scalar_bytes].concat();
124
125 Ok(proof)
126 }
127
128 pub fn verify(&self, public_key: &[u8], pi: &[u8], alpha: &[u8]) -> Result<GenericArray<u8, C::FieldBytesSize>> {
141 let public_key_point = C::ProjectivePoint::from(self.point_from_bytes(public_key)?);
143
144 if public_key_point.is_small_order().into() {
147 return Err(VrfError::VerifyInvalidKey);
148 }
149
150 let (gamma_point_bytes, c_scalar_bytes, s_scalar_bytes) = self.decode_proof(pi)?;
152 let gamma_point = C::ProjectivePoint::from(self.point_from_bytes(&gamma_point_bytes)?);
153 let c_scalar = self.scalar_from_bytes(&c_scalar_bytes)?;
154 let s_scalar = self.scalar_from_bytes(&s_scalar_bytes)?;
155
156 let h_point = ProjectivePoint::<C>::from(self.encode_to_curve_tai(public_key, alpha)?);
158 let h_point_bytes = h_point.to_encoded_point(true).as_bytes().to_vec();
159
160 let u_point = C::ProjectivePoint::mul_by_generator(&s_scalar) - public_key_point * c_scalar;
162 let u_point_bytes = u_point.to_encoded_point(true).as_bytes().to_vec();
163
164 let v_point = h_point * s_scalar - gamma_point * c_scalar;
166 let v_point_bytes = v_point.to_encoded_point(true).as_bytes().to_vec();
167
168 let derived_c_bytes = self.challenge_generation(
170 &[
171 public_key,
172 &h_point_bytes,
173 &gamma_point_bytes,
174 &u_point_bytes,
175 &v_point_bytes,
176 ],
177 self.c_len(),
178 )?;
179 let mut padded_derived_c_bytes: Vec<u8> = vec![0; C::FieldBytesSize::USIZE - self.c_len()];
180 padded_derived_c_bytes.extend_from_slice(&derived_c_bytes);
181
182 if padded_derived_c_bytes != c_scalar_bytes {
184 return Err(VrfError::InvalidProof);
185 }
186
187 self.gamma_to_hash(&gamma_point)
189 }
190
191 pub fn proof_to_hash(&self, pi: &[u8]) -> Result<GenericArray<u8, C::FieldBytesSize>> {
202 let gamma_point_bytes = self.decode_proof(pi)?.0;
203 let gamma_point = C::ProjectivePoint::from(self.point_from_bytes(&gamma_point_bytes)?);
204
205 self.gamma_to_hash(&gamma_point)
206 }
207
208 pub(crate) fn gamma_to_hash(&self, gamma: &C::ProjectivePoint) -> Result<GenericArray<u8, C::FieldBytesSize>> {
219 const PROOF_TO_HASH_DOMAIN_SEPARATOR_FRONT: u8 = 0x03;
221
222 const PROOF_TO_HASH_DOMAIN_SEPARATOR_BACK: u8 = 0x00;
224
225 let point: ProjectivePoint<C> = gamma.mul(self.cofactor());
229 let point_bytes = point.to_encoded_point(true).as_bytes().to_vec();
230
231 Ok(D::digest(
232 [
233 &[self.suite_id],
234 &[PROOF_TO_HASH_DOMAIN_SEPARATOR_FRONT],
235 &point_bytes[..],
236 &[PROOF_TO_HASH_DOMAIN_SEPARATOR_BACK],
237 ]
238 .concat(),
239 ))
240 }
241
242 pub(crate) fn decode_proof(&self, pi: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
253 let gamma_oct = self.pt_len() + 1;
256 if pi.len() != gamma_oct + self.c_len() * 3 {
257 return Err(VrfError::InvalidPiLength);
258 }
259
260 let gamma = pi[0..gamma_oct].to_vec();
262
263 let mut c_scalar: Vec<u8> = vec![0; <C as Curve>::FieldBytesSize::USIZE - self.c_len()];
265 c_scalar.extend_from_slice(&pi[gamma_oct..gamma_oct + self.c_len()]);
266
267 let s_scalar = pi[gamma_oct + self.c_len()..].to_vec();
269
270 Ok((gamma, c_scalar, s_scalar))
271 }
272}