1use super::proof_of_time::{iterate_squarings, serialize};
16use vdf_classgroup::{gmp_classgroup::GmpClassGroup, BigNum, BigNumExt, ClassGroup};
17use sha2::{Digest, Sha256};
18use std::{cmp::Eq, collections::HashMap, hash::Hash, mem, u64, usize};
19
/// Wesolowski verifiable delay function evaluated over GMP-backed class groups.
///
/// The only state is the size parameter forwarded to discriminant
/// derivation and used to size serialized group elements.
#[derive(Clone, Debug)]
pub struct WesolowskiVDF {
    /// Bit-size setting passed to `create_discriminant` and to the
    /// serialization helpers.
    int_size_bits: u16,
}
24use super::InvalidIterations as Bad;
25
26#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Debug)]
27pub struct WesolowskiVDFParams(pub u16);
28impl super::VDFParams for WesolowskiVDFParams {
29 type VDF = WesolowskiVDF;
30 fn new(self) -> Self::VDF {
31 WesolowskiVDF {
32 int_size_bits: self.0,
33 }
34 }
35}
36
37impl super::VDF for WesolowskiVDF {
38 fn check_difficulty(&self, _difficulty: u64) -> Result<(), Bad> {
39 Ok(())
40 }
41 fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
42 if difficulty > usize::MAX as u64 {
43 Err(Bad("Cannot have more that usize::MAX iterations".to_owned()))
44 } else {
45 Ok(create_proof_of_time_wesolowski::<
46 <GmpClassGroup as ClassGroup>::BigNum,
47 GmpClassGroup,
48 >(
49 challenge, difficulty as usize, self.int_size_bits
50 ))
51 }
52 }
53
54 fn verify(
55 &self,
56 challenge: &[u8],
57 difficulty: u64,
58 alleged_solution: &[u8],
59 ) -> Result<(), super::InvalidProof> {
60 check_proof_of_time_wesolowski::<<GmpClassGroup as ClassGroup>::BigNum, GmpClassGroup>(
61 challenge,
62 alleged_solution,
63 difficulty,
64 self.int_size_bits,
65 )
66 .map_err(|()| super::InvalidProof)
67 }
68}
/// Chooses the `(l, k, w)` proof-generation parameters for `t` iterations.
///
/// * `l` is `1` unless `log2(t)` exceeds `log2(10_000_000)`. In that case it
///   becomes `ceil(2^(log2(10_000_000) - 20))`.
/// * `k` is `round(ln(x) - ln(ln(x)) + 0.25)`, clamped below at `1`, where
///   `x = t * ln(2) / (2 * l)`. When `x` is tiny the expression is NaN, and
///   `f64::max` then yields `1`.
/// * `w` is `floor(t / (t / k + l * 2^(k + 1)) - 2)`. The final cast to `u64`
///   saturates, so negative values become `0`.
pub fn approximate_parameters(t: f64) -> (usize, u8, u64) {
    // Fixed memory-budget constant (10,000,000), in log2 form.
    let memory_log2 = 10_000_000.0f64.log2();
    let t_log2 = t.log2();

    let l = if t_log2 - memory_log2 > 0. {
        2.0f64.powf(memory_log2 - 20.).ceil()
    } else {
        1.
    };

    let x = t * 2.0f64.ln() / (2.0 * l);
    let ln_x = x.ln();
    let k = (ln_x - ln_x.ln() + 0.25).round().max(1.);

    let denominator = t / k + l * 2.0f64.powf(k + 1.0);
    let w = (t / denominator - 2.0).floor();

    (l as usize, k as u8, w as u64)
}
90
/// Encodes `q` as 8 bytes, most significant byte first (big-endian).
fn u64_to_bytes(q: u64) -> [u8; 8] {
    let mut bytes = [0u8; 8];
    for (idx, byte) in bytes.iter_mut().enumerate() {
        // Byte 0 takes bits 56..64, byte 7 takes bits 0..8.
        *byte = (q >> (56 - 8 * idx)) as u8;
    }
    bytes
}
95
96fn hash_prime<T: BigNum>(seed: &[&[u8]]) -> T {
100 let mut j = 0u64;
101 loop {
102 let mut hasher = Sha256::new();
103 hasher.update(b"prime");
104 hasher.update(u64_to_bytes(j));
105 for i in seed {
106 hasher.update(i);
107 }
108 let n = T::from(&hasher.finalize()[..16]);
109 if n.probab_prime(2) {
110 break n;
111 }
112 j += 1;
113 }
114}
115
116fn get_block<T: BigNumExt>(i: u64, k: u8, t: u64, b: &T) -> T {
121 let mut res = T::from(0);
122 let two = T::from(2);
123 res.mod_powm(&two, &T::from(t - u64::from(k) * (i + 1)), b);
124 res *= &((two >> 1) << (k as usize));
125 res / b
126}
127
128fn eval_optimized<T, U: BigNumExt, L: ClassGroup<BigNum = U> + Eq + Hash>(
129 h: &L,
130 b: &U,
131 t: usize,
132 k: u8,
133 l: usize,
134 powers: &T,
135) -> L
136where
137 T: for<'a> std::ops::Index<&'a u64, Output = L>,
138{
139 assert!(k > 0, "k cannot be zero");
140 assert!(l > 0, "l cannot be zero");
141 let kl = (k as usize)
142 .checked_mul(l)
143 .expect("computing k*l overflowed a u64");
144 assert!(kl <= u64::MAX as _);
145 assert!((kl as u64) < (1u64 << 53), "k*l overflowed an f64");
146 assert!((t as u64) < (1u64 << 53), "t overflows an f64");
147 assert!(
148 k < (mem::size_of::<usize>() << 3) as u8,
149 "k must be less than the number of bits in a usize"
150 );
151 let k1 = k >> 1;
152 let k0 = k - k1;
153 let mut x = h.identity();
154 let identity = h.identity();
155 let k_exp = 1usize << k;
156 let k0_exp = 1usize << k0;
157 let k1_exp = 1usize << k1;
158 for j in (0..l).rev() {
159 x.pow(U::from(k_exp as u64));
160 let mut ys: HashMap<U, L> = HashMap::new();
161 for b in 0..1usize << k {
162 ys.entry(U::from(b as u64))
163 .or_insert_with(|| identity.clone());
164 }
165 let end_of_loop = ((t as f64) / kl as f64).ceil() as usize;
166 assert!(end_of_loop == 0 || (end_of_loop as u64 - 1).checked_mul(l as u64).is_some());
167 for i in 0..end_of_loop {
168 if t < k as usize * (i * l + j + 1) {
169 continue;
170 }
171 let b = get_block((i as u64) * (l as u64), k, t as _, b);
172 *ys.get_mut(&b).unwrap() *= &powers[&((i * kl) as _)];
173 }
174
175 for b1 in 0..k1_exp {
176 let mut z = identity.clone();
177 for b0 in 0..k0_exp {
178 z *= &ys[&U::from((b1 * k0_exp + b0) as u64)]
179 }
180 z.pow(U::from((b1 as u64) * (k0_exp as u64)));
181 x *= &z;
182 }
183
184 for b0 in 0..k0_exp {
185 let mut z = identity.clone();
186 for b1 in 0..k1_exp {
187 z *= &ys[&U::from((b1 * k0_exp + b0) as u64)];
188 }
189 z.pow(U::from(b0 as u64));
190 x *= &z;
191 }
192 }
193 x
194}
195
196pub fn generate_proof<U, T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
197 x: &V,
198 iterations: u64,
199 k: u8,
200 l: usize,
201 powers: &U,
202 int_size_bits: usize,
203) -> V
204where
205 U: for<'a> std::ops::Index<&'a u64, Output = V>,
206{
207 let element_len = 2 * ((int_size_bits + 16) >> 4);
208 let mut x_buf = vec![0; element_len];
209 x.serialize(&mut x_buf[..])
210 .expect(super::INCORRECT_BUFFER_SIZE);
211 let mut y_buf = vec![0; element_len];
212 powers[&iterations]
213 .serialize(&mut y_buf[..])
214 .expect(super::INCORRECT_BUFFER_SIZE);
215 let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
216 eval_optimized(&x, &b, iterations as _, k, l, powers)
217}
218
219pub fn verify_proof<T: BigNum, V: ClassGroup<BigNum = T>>(
221 mut x: V,
222 y: &V,
223 mut proof: V,
224 t: u64,
225 int_size_bits: usize,
226) -> Result<(), ()> {
227 let element_len = 2 * ((int_size_bits + 16) >> 4);
228 let mut x_buf = vec![0; element_len];
229 x.serialize(&mut x_buf[..])
230 .expect(super::INCORRECT_BUFFER_SIZE);
231 let mut y_buf = vec![0; element_len];
232 y.serialize(&mut y_buf[..])
233 .expect(super::INCORRECT_BUFFER_SIZE);
234 let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
235 let mut r = T::from(0);
236 r.mod_powm(&T::from(2u64), &T::from(t), &b);
237 proof.pow(b);
238 x.pow(r);
239 proof *= &x;
240 if &proof == y {
241 Ok(())
242 } else {
243 Err(())
244 }
245}
246
247pub fn create_proof_of_time_wesolowski<T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
248 challenge: &[u8],
249 iterations: usize,
250 int_size_bits: u16,
251) -> Vec<u8>
252where
253 for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
254 for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
255{
256 let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
257 let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant);
258 assert!((iterations as u128) < (1u128 << 53));
259 let (l, k, _) = approximate_parameters(iterations as f64);
260 let q = l.checked_mul(k as _).expect("bug");
261 let powers = iterate_squarings(
262 x.clone(),
263 (0..=iterations / q + 1)
264 .map(|i| i * q)
265 .chain(Some(iterations))
266 .map(|x| x as _),
267 );
268 let proof = generate_proof(&x, iterations as _, k, l, &powers, int_size_bits.into());
269 serialize(&[proof], &powers[&(iterations as _)], int_size_bits.into())
270}
271
272pub fn check_proof_of_time_wesolowski<T: BigNum, V: ClassGroup<BigNum = T>>(
273 challenge: &[u8],
274 proof_blob: &[u8],
275 iterations: u64,
276 int_size_bits: u16,
277) -> Result<(), ()>
278where
279 T: BigNumExt,
280{
281 let discriminant: T = super::create_discriminant::create_discriminant(challenge, int_size_bits);
282 let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant.clone());
283 if (usize::MAX - 16) < int_size_bits.into() {
284 return Err(());
285 }
286 let int_size = (usize::from(int_size_bits) + 16) >> 4;
287 if int_size * 4 != proof_blob.len() {
288 return Err(());
289 }
290 let (result_bytes, proof_bytes) = proof_blob.split_at(2 * int_size);
291 let proof = ClassGroup::from_bytes(proof_bytes, discriminant.clone());
292 let y = ClassGroup::from_bytes(result_bytes, discriminant);
293
294 verify_proof(x, &y, proof, iterations, int_size_bits.into())
295}