1use super::proof_of_time::{iterate_squarings, serialize};
16use classgroup::{gmp_classgroup::GmpClassGroup, BigNum, BigNumExt, ClassGroup};
17use sha2::{digest::FixedOutput, Digest, Sha256};
18use std::{cmp::Eq, collections::HashMap, hash::Hash, mem, u64, usize};
19
/// A Wesolowski verifiable delay function over class groups.
///
/// Its only configuration is the bit size handed to discriminant derivation when
/// creating and checking proofs.
#[derive(Clone, Debug)]
pub struct WesolowskiVDF {
    /// Discriminant size in bits, forwarded to proof creation and verification.
    int_size_bits: u16,
}
24use super::InvalidIterations as Bad;
25
/// Construction parameters for a `WesolowskiVDF`.
///
/// The wrapped value is the discriminant size in bits.
#[derive(Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Copy, Clone)]
pub struct WesolowskiVDFParams(pub u16);
28impl super::VDFParams for WesolowskiVDFParams {
29 type VDF = WesolowskiVDF;
30 fn new(self) -> Self::VDF {
31 WesolowskiVDF {
32 int_size_bits: self.0,
33 }
34 }
35}
36
37impl super::VDF for WesolowskiVDF {
38 fn check_difficulty(&self, _difficulty: u64) -> Result<(), Bad> {
39 Ok(())
40 }
41 fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
42 if difficulty > usize::MAX as u64 {
43 Err(Bad("Cannot have more that usize::MAX iterations".to_owned()))
44 } else {
45 Ok(create_proof_of_time_wesolowski::<
46 <GmpClassGroup as ClassGroup>::BigNum,
47 GmpClassGroup,
48 >(
49 challenge, difficulty as usize, self.int_size_bits
50 ))
51 }
52 }
53
54 fn verify(
55 &self,
56 challenge: &[u8],
57 difficulty: u64,
58 alleged_solution: &[u8],
59 ) -> Result<(), super::InvalidProof> {
60 check_proof_of_time_wesolowski::<<GmpClassGroup as ClassGroup>::BigNum, GmpClassGroup>(
61 challenge,
62 alleged_solution,
63 difficulty,
64 self.int_size_bits,
65 )
66 .map_err(|()| super::InvalidProof)
67 }
68}
/// Chooses the `(l, k, w)` proof parameters for `t` squarings.
///
/// - `l` is `1` unless `log2(t)` exceeds `log2(10_000_000)`. In that case it
///   becomes `ceil(2^(log2(10_000_000) - 20))`.
/// - `k` is `round(ln(x) - ln(ln(x)) + 0.25)` with `x = t * ln(2) / (2 * l)`,
///   clamped to at least `1`. The first two terms are the usual asymptotic
///   approximation of the Lambert W function.
/// - `w` is `floor(t / (t / k + l * 2^(k + 1)) - 2)`.
///
/// The results are converted with `as` casts, so out-of-range values saturate
/// (for example, a negative `w` becomes `0`).
pub fn approximate_parameters(t: f64) -> (usize, u8, u64) {
    let ln_2 = 2.0f64.ln();
    let memory_log2 = 10_000_000.0f64.log2();

    let l: f64 = if t.log2() - memory_log2 > 0.0 {
        2.0f64.powf(memory_log2 - 20.0).ceil()
    } else {
        1.0
    };

    let x = t * ln_2 / (2.0 * l);
    let ln_x = x.ln();
    let k = (ln_x - ln_x.ln() + 0.25).round().max(1.0);

    let proof_cost = t / k + l * 2.0f64.powf(k + 1.0);
    let w = (t / proof_cost - 2.0).floor();

    (l as usize, k as u8, w as u64)
}
90
/// Encodes `q` as eight big-endian bytes, most significant byte first.
///
/// `hash_prime` uses this to feed its retry counter into the hash. The standard
/// library conversion replaces a dead `unsafe` transmute branch and
/// hand-written shifts.
fn u64_to_bytes(q: u64) -> [u8; 8] {
    q.to_be_bytes()
}
108
109fn hash_prime<T: BigNum>(seed: &[&[u8]]) -> T {
113 let mut j = 0u64;
114 loop {
115 let mut hasher = Sha256::new();
116 hasher.input(b"prime");
117 hasher.input(u64_to_bytes(j));
118 for i in seed {
119 hasher.input(i);
120 }
121 let n = T::from(&hasher.fixed_result()[..16]);
122 if n.probab_prime(2) {
123 break n;
124 }
125 j += 1;
126 }
127}
128
129fn get_block<T: BigNumExt>(i: u64, k: u8, t: u64, b: &T) -> T {
134 let mut res = T::from(0);
135 let two = T::from(2);
136 res.mod_powm(&two, &T::from(t - u64::from(k) * (i + 1)), b);
137 res *= &((two >> 1) << (k as usize));
138 res / b
139}
140
141fn eval_optimized<T, U: BigNumExt, L: ClassGroup<BigNum = U> + Eq + Hash>(
142 h: &L,
143 b: &U,
144 t: usize,
145 k: u8,
146 l: usize,
147 powers: &T,
148) -> L
149where
150 T: for<'a> std::ops::Index<&'a u64, Output = L>,
151{
152 assert!(k > 0, "k cannot be zero");
153 assert!(l > 0, "l cannot be zero");
154 let kl = (k as usize)
155 .checked_mul(l)
156 .expect("computing k*l overflowed a u64");
157 assert!(kl <= u64::MAX as _);
158 assert!((kl as u64) < (1u64 << 53), "k*l overflowed an f64");
159 assert!((t as u64) < (1u64 << 53), "t overflows an f64");
160 assert!(
161 k < (mem::size_of::<usize>() << 3) as u8,
162 "k must be less than the number of bits in a usize"
163 );
164 let k1 = k >> 1;
165 let k0 = k - k1;
166 let mut x = h.identity();
167 let identity = h.identity();
168 let k_exp = 1usize << k;
169 let k0_exp = 1usize << k0;
170 let k1_exp = 1usize << k1;
171 for j in (0..l).rev() {
172 x.pow(U::from(k_exp as u64));
173 let mut ys: HashMap<U, L> = HashMap::new();
174 for b in 0..1usize << k {
175 ys.entry(U::from(b as u64))
176 .or_insert_with(|| identity.clone());
177 }
178 let end_of_loop = ((t as f64) / kl as f64).ceil() as usize;
179 assert!(end_of_loop == 0 || (end_of_loop as u64 - 1).checked_mul(l as u64).is_some());
180 for i in 0..end_of_loop {
181 if t < k as usize * (i * l + j + 1) {
182 continue;
183 }
184 let b = get_block((i as u64) * (l as u64), k, t as _, b);
185 *ys.get_mut(&b).unwrap() *= &powers[&((i * kl) as _)];
186 }
187
188 for b1 in 0..k1_exp {
189 let mut z = identity.clone();
190 for b0 in 0..k0_exp {
191 z *= &ys[&U::from((b1 * k0_exp + b0) as u64)]
192 }
193 z.pow(U::from((b1 as u64) * (k0_exp as u64)));
194 x *= &z;
195 }
196
197 for b0 in 0..k0_exp {
198 let mut z = identity.clone();
199 for b1 in 0..k1_exp {
200 z *= &ys[&U::from((b1 * k0_exp + b0) as u64)];
201 }
202 z.pow(U::from(b0 as u64));
203 x *= &z;
204 }
205 }
206 x
207}
208
209pub fn generate_proof<U, T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
210 x: &V,
211 iterations: u64,
212 k: u8,
213 l: usize,
214 powers: &U,
215 int_size_bits: usize,
216) -> V
217where
218 U: for<'a> std::ops::Index<&'a u64, Output = V>,
219{
220 let element_len = 2 * ((int_size_bits + 16) >> 4);
221 let mut x_buf = vec![0; element_len];
222 x.serialize(&mut x_buf[..])
223 .expect(super::INCORRECT_BUFFER_SIZE);
224 let mut y_buf = vec![0; element_len];
225 powers[&iterations]
226 .serialize(&mut y_buf[..])
227 .expect(super::INCORRECT_BUFFER_SIZE);
228 let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
229 eval_optimized(&x, &b, iterations as _, k, l, powers)
230}
231
232pub fn verify_proof<T: BigNum, V: ClassGroup<BigNum = T>>(
234 mut x: V,
235 y: &V,
236 mut proof: V,
237 t: u64,
238 int_size_bits: usize,
239) -> Result<(), ()> {
240 let element_len = 2 * ((int_size_bits + 16) >> 4);
241 let mut x_buf = vec![0; element_len];
242 x.serialize(&mut x_buf[..])
243 .expect(super::INCORRECT_BUFFER_SIZE);
244 let mut y_buf = vec![0; element_len];
245 y.serialize(&mut y_buf[..])
246 .expect(super::INCORRECT_BUFFER_SIZE);
247 let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
248 let mut r = T::from(0);
249 r.mod_powm(&T::from(2u64), &T::from(t), &b);
250 proof.pow(b);
251 x.pow(r);
252 proof *= &x;
253 if &proof == y {
254 Ok(())
255 } else {
256 Err(())
257 }
258}
259
260pub fn create_proof_of_time_wesolowski<T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
261 challenge: &[u8],
262 iterations: usize,
263 int_size_bits: u16,
264) -> Vec<u8>
265where
266 for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
267 for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
268{
269 let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
270 let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant);
271 assert!((iterations as u128) < (1u128 << 53));
272 let (l, k, _) = approximate_parameters(iterations as f64);
273 let q = l.checked_mul(k as _).expect("bug");
274 let powers = iterate_squarings(
275 x.clone(),
276 (0..=iterations / q + 1)
277 .map(|i| i * q)
278 .chain(Some(iterations))
279 .map(|x| x as _),
280 );
281 let proof = generate_proof(&x, iterations as _, k, l, &powers, int_size_bits.into());
282 serialize(&[proof], &powers[&(iterations as _)], int_size_bits.into())
283}
284
285pub fn check_proof_of_time_wesolowski<T: BigNum, V: ClassGroup<BigNum = T>>(
286 challenge: &[u8],
287 proof_blob: &[u8],
288 iterations: u64,
289 int_size_bits: u16,
290) -> Result<(), ()>
291where
292 T: BigNumExt,
293{
294 let discriminant: T = super::create_discriminant::create_discriminant(challenge, int_size_bits);
295 let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant.clone());
296 if (usize::MAX - 16) < int_size_bits.into() {
297 return Err(());
298 }
299 let int_size = (usize::from(int_size_bits) + 16) >> 4;
300 if int_size * 4 != proof_blob.len() {
301 return Err(());
302 }
303 let (result_bytes, proof_bytes) = proof_blob.split_at(2 * int_size);
304 let proof = ClassGroup::from_bytes(proof_bytes, discriminant.clone());
305 let y = ClassGroup::from_bytes(result_bytes, discriminant);
306
307 verify_proof(x, &y, proof, iterations, int_size_bits.into())
308}