1use crate::error::FastCryptoError;
5use crate::traits::AllowedRng;
6
/// Public key of a verifiable random function (VRF).
///
/// The associated private key type must name `Self` as its public key, so a public key type and
/// its private key type are statically paired in both directions.
pub trait VRFPublicKey {
    /// The private key type that corresponds to this public key type.
    type PrivateKey: VRFPrivateKey<PublicKey = Self>;
}
11
/// Private (secret) key of a verifiable random function (VRF).
///
/// The associated public key type must name `Self` as its private key, mirroring
/// `VRFPublicKey::PrivateKey`.
pub trait VRFPrivateKey {
    /// The public key type that corresponds to this private key type.
    type PublicKey: VRFPublicKey<PrivateKey = Self>;
}
16
17pub trait VRFKeyPair<const OUTPUT_SIZE: usize> {
19 type Proof: VRFProof<OUTPUT_SIZE, PublicKey = Self::PublicKey>;
20 type PrivateKey: VRFPrivateKey<PublicKey = Self::PublicKey>;
21 type PublicKey: VRFPublicKey<PrivateKey = Self::PrivateKey>;
22
23 fn generate<R: AllowedRng>(rng: &mut R) -> Self;
25
26 fn prove(&self, input: &[u8]) -> Self::Proof;
28
29 fn output(&self, input: &[u8]) -> ([u8; OUTPUT_SIZE], Self::Proof) {
31 let proof = self.prove(input);
32 let output = proof.to_hash();
33 (output, proof)
34 }
35}
36
37pub trait VRFProof<const OUTPUT_SIZE: usize> {
39 type PublicKey: VRFPublicKey;
40
41 fn verify(&self, input: &[u8], public_key: &Self::PublicKey) -> Result<(), FastCryptoError>;
43
44 fn verify_output(
46 &self,
47 input: &[u8],
48 public_key: &Self::PublicKey,
49 output: &[u8; OUTPUT_SIZE],
50 ) -> Result<(), FastCryptoError> {
51 self.verify(input, public_key)?;
52 if &self.to_hash() != output {
53 return Err(FastCryptoError::GeneralOpaqueError);
54 }
55 Ok(())
56 }
57
58 fn to_hash(&self) -> [u8; OUTPUT_SIZE];
60}
61
62pub mod ecvrf {
66 use crate::error::FastCryptoError;
67 use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
68 use crate::groups::{GroupElement, MultiScalarMul, Scalar};
69 use crate::hash::{HashFunction, ReverseWrapper, Sha512};
70 use crate::serde_helpers::ToFromByteArray;
71 use crate::traits::AllowedRng;
72 use crate::vrf::{VRFKeyPair, VRFPrivateKey, VRFProof, VRFPublicKey};
73 use elliptic_curve::hash2curve::{ExpandMsg, Expander};
74 use serde::{Deserialize, Serialize};
75 use zeroize::ZeroizeOnDrop;
76
    /// Suite string prefixed to the inputs of the challenge hash (`ecvrf_challenge_generation`)
    /// and of the proof-to-hash computation (`ECVRFProof::to_hash`). Changing it changes every
    /// challenge and every VRF output.
    const SUITE_STRING: &[u8; 7] = b"sui_vrf";

    /// Length in bytes of the proof challenge `c` (16 bytes = 128 bits): the challenge digest is
    /// truncated to this many bytes.
    const C_LEN: usize = 16;

    /// Hash function used for encode-to-curve, nonce generation, challenge generation and the
    /// final proof-to-hash.
    type H = Sha512;

    /// Domain separation tag passed to `expand_message_xmd` in `ecvrf_encode_to_curve`.
    /// Changing it changes the point every input is mapped to, and hence every proof and output.
    const DST: &[u8; 49] = b"ECVRF_ristretto255_XMD:SHA-512_R255MAP_RO_sui_vrf";
90
    /// ECVRF public key: a Ristretto255 point. Key pairs built from a private key (see
    /// `From<ECVRFPrivateKey> for ECVRFKeyPair`) set it to `G * sk`.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    pub struct ECVRFPublicKey(RistrettoPoint);

    impl VRFPublicKey for ECVRFPublicKey {
        type PrivateKey = ECVRFPrivateKey;
    }
97
98 impl ECVRFPublicKey {
99 fn ecvrf_encode_to_curve(&self, alpha_string: &[u8]) -> RistrettoPoint {
101 let mut expanded_message = elliptic_curve::hash2curve::ExpandMsgXmd::<
109 <H as ReverseWrapper>::Variant,
110 >::expand_message(
111 &[&self.0.compress(), alpha_string],
112 &[DST],
113 H::OUTPUT_SIZE,
114 )
115 .unwrap();
116
117 let mut bytes = [0u8; H::OUTPUT_SIZE];
118 expanded_message.fill_bytes(&mut bytes);
119 RistrettoPoint::from_uniform_bytes(&bytes)
120 }
121
122 fn valid(&self) -> bool {
125 self.0 != RistrettoPoint::zero()
126 }
127 }
128
    /// ECVRF private key: a secret Ristretto255 scalar. The `ZeroizeOnDrop` derive wipes the
    /// scalar when the key is dropped.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, ZeroizeOnDrop)]
    pub struct ECVRFPrivateKey(RistrettoScalar);

    impl VRFPrivateKey for ECVRFPrivateKey {
        type PublicKey = ECVRFPublicKey;
    }
135
136 impl ECVRFPrivateKey {
137 fn ecvrf_nonce_generation(&self, h_string: &[u8]) -> RistrettoScalar {
139 let hashed_sk_string = H::digest(self.0.to_byte_array());
140 let mut truncated_hashed_sk_string = [0u8; 32];
141 truncated_hashed_sk_string.copy_from_slice(&hashed_sk_string.digest[32..64]);
142
143 let mut hash_function = H::default();
144 hash_function.update(truncated_hashed_sk_string);
145 hash_function.update(h_string);
146 let k_string = hash_function.finalize();
147
148 RistrettoScalar::from_bytes_mod_order_wide(&k_string.digest)
149 }
150 }
151
    /// An ECVRF key pair.
    ///
    /// Key pairs created by `generate` or `From<ECVRFPrivateKey>` satisfy `pk = G * sk`.
    /// NOTE(review): the fields are public and `Deserialize` is derived, so neither direct
    /// construction nor deserialization checks that `pk` matches `sk`.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    pub struct ECVRFKeyPair {
        /// Public key.
        pub pk: ECVRFPublicKey,
        /// Private key; zeroized on drop by its own `ZeroizeOnDrop` derive.
        pub sk: ECVRFPrivateKey,
    }

    // Marker impl only, with no extra `Drop` logic. The secret material lives in `sk`, whose own
    // `ZeroizeOnDrop` derive wipes it when the key pair is dropped.
    impl ZeroizeOnDrop for ECVRFKeyPair {}
159
160 fn ecvrf_challenge_generation(points: [&RistrettoPoint; 5]) -> Challenge {
162 let mut hash = H::default();
163 hash.update(SUITE_STRING);
164 hash.update([0x02]); points.into_iter().for_each(|p| hash.update(p.compress()));
166 hash.update([0x00]); let digest = hash.finalize();
168
169 let mut challenge_bytes = [0u8; C_LEN];
170 challenge_bytes.copy_from_slice(&digest.digest[..C_LEN]);
171 Challenge(challenge_bytes)
172 }
173
174 impl ECVRFKeyPair {
175 pub fn public_key(&self) -> RistrettoPoint {
176 self.pk.0
177 }
178 }
179
    /// The `C_LEN`-byte challenge `c` of an ECVRF proof, as produced by
    /// `ecvrf_challenge_generation` (a truncated hash digest).
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
    pub struct Challenge(pub [u8; C_LEN]);
183
184 impl From<&Challenge> for RistrettoScalar {
185 fn from(c: &Challenge) -> Self {
186 let mut scalar = [0u8; 32];
187 scalar[..C_LEN].copy_from_slice(&c.0);
188 RistrettoScalar::from_bytes_mod_order(&scalar)
189 }
190 }
191
192 impl VRFKeyPair<64> for ECVRFKeyPair {
193 type Proof = ECVRFProof;
194 type PrivateKey = ECVRFPrivateKey;
195 type PublicKey = ECVRFPublicKey;
196
197 fn generate<R: AllowedRng>(rng: &mut R) -> Self {
198 let s = RistrettoScalar::rand(rng);
199 ECVRFKeyPair::from(ECVRFPrivateKey(s))
200 }
201
202 fn prove(&self, alpha_string: &[u8]) -> ECVRFProof {
203 let h = self.pk.ecvrf_encode_to_curve(alpha_string);
206 let h_string = h.compress();
207 let gamma = h * self.sk.0;
208 let k = self.sk.ecvrf_nonce_generation(&h_string);
209
210 let c = ecvrf_challenge_generation([
211 &self.pk.0,
212 &h,
213 &gamma,
214 &(RistrettoPoint::generator() * k),
215 &(h * k),
216 ]);
217 let s = k + RistrettoScalar::from(&c) * self.sk.0;
218
219 ECVRFProof { gamma, c, s }
220 }
221 }
222
223 impl From<ECVRFPrivateKey> for ECVRFKeyPair {
224 fn from(sk: ECVRFPrivateKey) -> Self {
225 let p = RistrettoPoint::generator() * sk.0;
226 ECVRFKeyPair {
227 pk: ECVRFPublicKey(p),
228 sk,
229 }
230 }
231 }
232
    /// An ECVRF proof `(gamma, c, s)` as produced by `ECVRFKeyPair::prove`.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    pub struct ECVRFProof {
        /// `ecvrf_encode_to_curve(alpha) * sk`; the VRF output is a hash of this point.
        pub gamma: RistrettoPoint,
        /// Truncated challenge hash.
        pub c: Challenge,
        /// Response scalar `k + c * sk`.
        pub s: RistrettoScalar,
    }
239
240 impl VRFProof<64> for ECVRFProof {
241 type PublicKey = ECVRFPublicKey;
242
243 fn verify(
244 &self,
245 alpha_string: &[u8],
246 public_key: &Self::PublicKey,
247 ) -> Result<(), FastCryptoError> {
248 if !public_key.valid() {
251 return Err(FastCryptoError::InvalidInput);
252 }
253
254 let h = public_key.ecvrf_encode_to_curve(alpha_string);
255
256 let challenge = RistrettoScalar::from(&self.c);
257 let u = RistrettoPoint::multi_scalar_mul(
258 &[self.s, -challenge],
259 &[RistrettoPoint::generator(), public_key.0],
260 )?;
261 let v = RistrettoPoint::multi_scalar_mul(&[self.s, -challenge], &[h, self.gamma])?;
262
263 let c_prime = ecvrf_challenge_generation([&public_key.0, &h, &self.gamma, &u, &v]);
264
265 if c_prime != self.c {
266 return Err(FastCryptoError::GeneralOpaqueError);
267 }
268 Ok(())
269 }
270
271 fn to_hash(&self) -> [u8; 64] {
272 let mut hash = H::default();
274 hash.update(SUITE_STRING);
275 hash.update([0x03]); hash.update(self.gamma.compress());
277 hash.update([0x00]); hash.finalize().digest
279 }
280 }
281
282 impl ECVRFProof {
283 pub fn gamma_bytes(&self) -> [u8; 32] {
285 self.gamma.compress()
286 }
287
288 pub fn challenge_bytes(&self) -> [u8; C_LEN] {
290 self.c.0
291 }
292
293 pub fn scalar_bytes(&self) -> [u8; 32] {
295 self.s.to_byte_array()
296 }
297
298 pub fn from_components(
308 gamma_bytes: &[u8; 32],
309 challenge_bytes: &[u8; C_LEN],
310 scalar_bytes: &[u8; 32],
311 ) -> Result<Self, FastCryptoError> {
312 let gamma = RistrettoPoint::try_from(gamma_bytes.as_slice())?;
314
315 let c = Challenge(*challenge_bytes);
317
318 let s = RistrettoScalar::from_byte_array(scalar_bytes)?;
320
321 Ok(ECVRFProof { gamma, c, s })
322 }
323
324 pub fn to_components(&self) -> ([u8; 32], [u8; C_LEN], [u8; 32]) {
329 (self.gamma_bytes(), self.challenge_bytes(), self.scalar_bytes())
330 }
331 }
332}