1use crate::util::*;
2use bs58;
3use curve25519_dalek::constants::{
4 RISTRETTO_BASEPOINT_POINT as G, RISTRETTO_BASEPOINT_TABLE as GT,
5};
6use std::borrow::Borrow;
7use subtle::{ConditionallySelectable, ConstantTimeEq};
8
9#[derive(Clone)]
10pub struct PublicKey(pub(crate) [u8; 32], pub(crate) Point);
11#[derive(Clone)]
12pub struct SecretKey(Scalar, PublicKey);
13value_type!(pub, Value, 32, "value");
14value_type!(pub, Proof, 64, "proof");
15
16impl PublicKey {
17 fn from_bytes(bytes: &[u8; 32]) -> Option<Self> {
18 Some(PublicKey(*bytes, unpack(bytes)?))
19 }
20
21 fn offset(&self, input: &[u8]) -> Scalar {
22 hash_s!(&self.0, input)
23 }
24
25 pub fn is_vrf_valid(&self, input: &impl Borrow<[u8]>, value: &Value, proof: &Proof) -> bool {
26 self.is_valid(input.borrow(), value, proof)
27 }
28
29 #[allow(clippy::arithmetic_side_effects)]
32 fn is_valid(&self, input: &[u8], value: &Value, proof: &Proof) -> bool {
33 let p = unwrap_or_return_false!(unpack(&value.0));
34 let (r, c) = unwrap_or_return_false!(unpack(&proof.0));
35 hash_s!(
36 &self.0,
37 &value.0,
38 vmul2(r + c * self.offset(input), &G, c, &self.1),
39 vmul2(r, &p, c, &G)
40 ) == c
41 }
42}
43
44#[allow(clippy::arithmetic_side_effects)]
48fn basemul(s: Scalar) -> Point {
49 &s * &*GT
50}
51
52fn safe_invert(s: Scalar) -> Scalar {
53 Scalar::conditional_select(&s, &Scalar::ONE, s.ct_eq(&Scalar::ZERO)).invert()
54}
55
56impl SecretKey {
57 pub(crate) fn from_scalar(sk: Scalar) -> Self {
58 let pk = basemul(sk);
59 SecretKey(sk, PublicKey(pk.pack(), pk))
60 }
61
62 fn from_bytes(bytes: &[u8; 32]) -> Option<Self> {
63 Some(Self::from_scalar(unpack(bytes)?))
64 }
65
66 pub fn public_key(&self) -> &PublicKey {
67 &self.1
68 }
69
70 pub fn compute_vrf(&self, input: &impl Borrow<[u8]>) -> Value {
71 self.compute(input.borrow())
72 }
73
74 #[allow(clippy::arithmetic_side_effects)]
78 fn compute(&self, input: &[u8]) -> Value {
79 Value(basemul(safe_invert(self.0 + self.1.offset(input))).pack())
80 }
81
82 pub fn compute_vrf_with_proof(&self, input: &impl Borrow<[u8]>) -> (Value, Proof) {
83 self.compute_with_proof(input.borrow())
84 }
85
86 #[allow(clippy::arithmetic_side_effects)]
90 fn compute_with_proof(&self, input: &[u8]) -> (Value, Proof) {
91 let x = self.0 + self.1.offset(input);
92 let inv = safe_invert(x);
93 let val = basemul(inv).pack();
94 let k = prs!(x);
95 let c = hash_s!(&(self.1).0, &val, basemul(k), basemul(inv * k));
96 (Value(val), Proof((k - c * x, c).pack()))
97 }
98
99 pub fn is_vrf_valid(&self, input: &impl Borrow<[u8]>, value: &Value, proof: &Proof) -> bool {
100 self.1.is_valid(input.borrow(), value, proof)
101 }
102}
103
104macro_rules! traits {
105 ($ty:ident, $l:literal, $bytes:expr, $what:literal) => {
106 eq!($ty, |a, b| a.0 == b.0);
107 common_conversions_fixed!($ty, 32, $bytes, $what);
108
109 impl TryFrom<&[u8; $l]> for $ty {
110 type Error = ();
111 fn try_from(value: &[u8; $l]) -> Result<Self, ()> {
112 Self::from_bytes(value).ok_or(())
113 }
114 }
115 };
116}
117
118traits!(PublicKey, 32, |s| &s.0, "public key");
119traits!(SecretKey, 32, |s| s.0.as_bytes(), "secret key");
120
#[cfg(test)]
mod tests {
    use super::*;

    use secp256k1::rand::rngs::OsRng;
    use serde::{Deserialize, Serialize};
    use serde_json::{from_str, to_string};

    fn random_secret_key() -> SecretKey {
        SecretKey::from_scalar(Scalar::random(&mut OsRng))
    }

    #[test]
    fn test_conversion() {
        let sk = random_secret_key();
        let sk2 = SecretKey::from_bytes(&sk.clone().into()).unwrap();
        assert_eq!(sk, sk2);
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        let pk3 = PublicKey::from_bytes(&pk2.into()).unwrap();
        assert_eq!(pk, pk2);
        assert_eq!(pk.clone(), pk3);
    }

    #[test]
    fn test_verify() {
        let sk = random_secret_key();
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let val2 = sk.compute_vrf(b"Test");
        assert_eq!(val, val2);
        assert!(sk.public_key().is_vrf_valid(b"Test", &val, &proof));
        assert!(!sk.public_key().is_vrf_valid(b"Tent", &val, &proof));
    }

    /// Covers `SecretKey::is_vrf_valid` and checks that a value/proof produced for one input
    /// cannot be reused for, or mixed with, those of another input.
    #[test]
    fn test_secret_key_verify_rejects_mismatched_pairs() {
        let sk = random_secret_key();
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let (val2, proof2) = sk.compute_vrf_with_proof(b"Tent");
        assert_ne!(val, val2);
        assert!(sk.is_vrf_valid(b"Test", &val, &proof));
        assert!(sk.is_vrf_valid(b"Tent", &val2, &proof2));
        assert!(!sk.is_vrf_valid(b"Test", &val, &proof2));
        assert!(!sk.is_vrf_valid(b"Test", &val2, &proof));
        assert!(!sk.is_vrf_valid(b"Tent", &val, &proof));
        assert!(!sk.is_vrf_valid(b"Tent", &val2, &proof));
    }

    #[test]
    fn test_different_keys() {
        let sk = random_secret_key();
        let sk2 = random_secret_key();
        assert_ne!(sk, sk2);
        assert_ne!(Into::<[u8; 32]>::into(sk.clone()), Into::<[u8; 32]>::into(sk2.clone()));
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        assert_ne!(pk, pk2);
        assert_ne!(Into::<[u8; 32]>::into(pk), Into::<[u8; 32]>::into(pk2));
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let (val2, proof2) = sk2.compute_vrf_with_proof(b"Test");
        assert_ne!(val, val2);
        assert_ne!(proof, proof2);
        assert!(!pk2.is_vrf_valid(b"Test", &val, &proof));
        assert!(!pk2.is_vrf_valid(b"Test", &val2, &proof));
        assert!(!pk2.is_vrf_valid(b"Test", &val, &proof2));
    }

    fn round_trip<T: Serialize + for<'de> Deserialize<'de>>(value: &T) -> T {
        from_str(to_string(value).unwrap().as_str()).unwrap()
    }

    #[test]
    fn test_serialize() {
        let sk = random_secret_key();
        let sk2 = round_trip(&sk);
        assert_eq!(sk, sk2);
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let (val2, proof2) = sk2.compute_vrf_with_proof(b"Test");
        let (val3, proof3) = (round_trip(&val), round_trip(&proof));
        assert_eq!((val, proof), (val2, proof2));
        assert_eq!((val, proof), (val3, proof3));
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        let pk3 = round_trip(pk);
        assert!(pk.is_vrf_valid(b"Test", &val, &proof));
        assert!(pk2.is_vrf_valid(b"Test", &val, &proof));
        assert!(pk3.is_vrf_valid(b"Test", &val, &proof));
    }
}