1use super::proof_of_time::{deserialize_proof, iterate_squarings, serialize};
15use vdf_classgroup::{gmp_classgroup::GmpClassGroup, BigNumExt, ClassGroup};
16use num_traits::{One, Zero};
17use std::{fmt, num::ParseIntError, ops::Index, str::FromStr, u64, usize};
18
/// A validated number of squarings for the Pietrzak VDF.
///
/// Values are normally produced by [`Iterations::new`] or `FromStr`, which only
/// accept even counts of at least 66.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Iterations(u64);
21
/// Reason a requested iteration count was rejected by [`Iterations::new`].
///
/// Variant order matters for the derived `Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum InvalidIterations {
    /// The count was odd; carries the rejected count.
    OddNumber(u64),
    /// The count was even but below 66; carries the rejected count.
    LessThan66(u64),
}
27
28#[derive(PartialEq, Eq, Clone, Debug)]
29pub struct ParseIterationsError {
30 kind: Result<InvalidIterations, ParseIntError>,
31}
32
33impl From<InvalidIterations> for ParseIterationsError {
34 fn from(t: InvalidIterations) -> Self {
35 Self { kind: Ok(t) }
36 }
37}
38
39impl From<ParseIntError> for ParseIterationsError {
40 fn from(t: ParseIntError) -> Self {
41 Self { kind: Err(t) }
42 }
43}
44
45impl fmt::Display for InvalidIterations {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match *self {
48 InvalidIterations::OddNumber(s) => {
49 write!(f, "Pietrzak iterations must be an even number, not {}", s)
50 }
51 InvalidIterations::LessThan66(s) => write!(
52 f,
53 "Pietrzak proof-of-time must run for at least 66 iterations, not {}",
54 s
55 ),
56 }
57 }
58}
59
60impl From<Iterations> for u64 {
61 fn from(t: Iterations) -> u64 {
62 t.0
63 }
64}
65
66impl fmt::Display for ParseIterationsError {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 match self.kind {
69 Ok(ref q) => <InvalidIterations as fmt::Display>::fmt(q, f),
70 Err(ref q) => <ParseIntError as fmt::Display>::fmt(q, f),
71 }
72 }
73}
74
75impl FromStr for Iterations {
76 type Err = ParseIterationsError;
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 Self::new(s.parse::<u64>().map_err(ParseIterationsError::from)?)
79 .map_err(ParseIterationsError::from)
80 }
81}
82
83impl Iterations {
84 pub fn new<T: Into<u64>>(iterations: T) -> Result<Iterations, InvalidIterations> {
85 let iterations = iterations.into();
86 if iterations & 1 != 0 {
87 Err(InvalidIterations::OddNumber(iterations))
88 } else if iterations < 66 {
89 Err(InvalidIterations::LessThan66(iterations))
90 } else {
91 Ok(Iterations(iterations))
92 }
93 }
94}
95
96fn approximate_i(t: Iterations) -> u64 {
98 let x: f64 = (((t.0 >> 1) as f64) / 8.) * 2.0f64.ln();
99 let w = x.ln() - x.ln().ln() + 0.25;
100 (w / (2. * 2.0f64.ln())).round() as _
101}
102
/// Returns the sum of every non-empty subset of `numbers`.
///
/// Partial sums are built by doubling the list once per input number: the
/// existing sums, followed by those sums plus the new number. The leading
/// empty-subset sum (0) is then dropped. The output is not sorted.
fn sum_combinations<'a, T: IntoIterator<Item = &'a u64>>(numbers: T) -> Vec<u64> {
    let mut sums: Vec<u64> = vec![0];
    for &n in numbers {
        let shifted: Vec<u64> = sums.iter().map(|&s| s + n).collect();
        sums.extend(shifted);
    }
    sums.remove(0);
    sums
}
115
116fn cache_indices_for_count(t: Iterations) -> Vec<u64> {
117 let i: u64 = approximate_i(t);
118 let mut curr_t = t.0;
119 let mut intermediate_ts = vec![];
120 for _ in 0..i {
121 curr_t >>= 1;
122 intermediate_ts.push(curr_t);
123 if curr_t & 1 != 0 {
124 curr_t += 1
125 }
126 }
127 let mut cache_indices = sum_combinations(&intermediate_ts);
128 cache_indices.sort();
129 cache_indices.push(t.0);
130 cache_indices
131}
132
133fn generate_r_value<T>(x: &T, y: &T, sqrt_mu: &T, int_size_bits: usize) -> T::BigNum
134where
135 T: ClassGroup,
136 for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
137 for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
138{
139 use sha2::{Digest, Sha256};
140
141 let size = (int_size_bits + 16) >> 4;
142 let mut v = Vec::with_capacity(size * 2);
143 for _ in 0..size * 2 {
144 v.push(0)
145 }
146 let mut hasher = Sha256::new();
147 for i in &[&x, &y, &sqrt_mu] {
148 i.serialize(&mut v).expect(super::INCORRECT_BUFFER_SIZE);
149 hasher.update(&v);
150 }
151 let res = hasher.finalize();
152 T::unsigned_deserialize_bignum(&res[..16])
153}
154
155fn create_proof_of_time_pietrzak<T>(
156 challenge: &[u8],
157 iterations: Iterations,
158 int_size_bits: u16,
159) -> Vec<u8>
160where
161 T: ClassGroup,
162 <T as ClassGroup>::BigNum: BigNumExt,
163 for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
164 for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
165{
166 let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
167 let x = T::from_ab_discriminant(2.into(), 1.into(), discriminant);
168
169 let delta = 8;
170 let powers_to_calculate = cache_indices_for_count(iterations);
171 let powers = iterate_squarings(x.clone(), powers_to_calculate.iter().cloned());
172 let proof: Vec<T> = generate_proof(
173 x,
174 iterations,
175 delta,
176 &powers,
177 &generate_r_value,
178 usize::from(int_size_bits),
179 );
180 serialize(
181 &proof,
182 &powers[&iterations.into()],
183 usize::from(int_size_bits),
184 )
185}
186
187pub fn check_proof_of_time_pietrzak<T>(
188 challenge: &[u8],
189 proof_blob: &[u8],
190 iterations: u64,
191 length_in_bits: u16,
192) -> Result<(), super::InvalidProof>
193where
194 T: ClassGroup,
195 T::BigNum: BigNumExt,
196 for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
197 for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
198{
199 let discriminant = super::create_discriminant::create_discriminant(&challenge, length_in_bits);
200 let x = T::from_ab_discriminant(2.into(), 1.into(), discriminant);
201 let iterations = Iterations::new(iterations).map_err(|_| super::InvalidProof)?;
202 if usize::MAX - 16 < length_in_bits.into() {
203 return Err(super::InvalidProof);
205 }
206 let length: usize = (usize::from(length_in_bits) + 16usize) >> 4;
207 if proof_blob.len() < 2 * length {
208 return Err(super::InvalidProof);
210 }
211 let result_bytes = &proof_blob[..length * 2];
212 let proof_bytes = &proof_blob[length * 2..];
213 let discriminant = x.discriminant().clone();
214 let proof =
215 deserialize_proof(proof_bytes, &discriminant, length).map_err(|()| super::InvalidProof)?;
216 let y = T::from_bytes(result_bytes, discriminant);
217 verify_proof(
218 &x,
219 &y,
220 proof,
221 iterations,
222 8,
223 &generate_r_value,
224 length_in_bits.into(),
225 )
226 .map_err(|()| super::InvalidProof)
227}
228
229fn calculate_final_t(t: Iterations, delta: usize) -> u64 {
230 let mut curr_t = t.0;
231 let mut ts = vec![];
232 while curr_t != 2 {
233 ts.push(curr_t);
234 curr_t >>= 1;
235 if curr_t & 1 == 1 {
236 curr_t += 1
237 }
238 }
239 ts.push(2);
240 ts.push(1);
241 assert!(ts.len() >= delta);
242 ts[ts.len() - delta]
243}
244
245pub fn generate_proof<T, U, V>(
246 x: V,
247 iterations: Iterations,
248 delta: usize,
249 powers: &T,
250 generate_r_value: &U,
251 int_size_bits: usize,
252) -> Vec<V>
253where
254 T: for<'a> Index<&'a u64, Output = V>,
255 U: Fn(&V, &V, &V, usize) -> V::BigNum,
256 V: ClassGroup,
257 for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
258 for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
259{
260 let identity = x.identity();
261 let i = approximate_i(iterations);
262 let mut mus = vec![];
263 let mut rs: Vec<V::BigNum> = vec![];
264 let mut x_p = vec![x];
265 let mut curr_t = iterations.0;
266
267 let mut y_p = vec![powers[&curr_t].clone()];
268
269 let mut ts = vec![];
270
271 let final_t = calculate_final_t(iterations, delta);
272
273 let mut round_index = 0;
274 while curr_t != final_t {
275 assert_eq!(curr_t & 1, 0);
276 let half_t = curr_t >> 1;
277 ts.push(half_t);
278 assert!(round_index < 63);
279 let denominator: u64 = 1 << (round_index + 1);
280
281 mus.push(if round_index < i {
282 let mut mu = identity.clone();
283 for numerator in (1..denominator).step_by(2) {
284 let num_bits = 62 - denominator.leading_zeros() as usize;
285 let mut r_prod: V::BigNum = One::one();
286 for b in (0..num_bits).rev() {
287 if 0 == (numerator & (1 << (b + 1))) {
288 r_prod *= &rs[num_bits - b - 1]
289 }
290 }
291 let mut t_sum = half_t;
292 for b in 0..num_bits {
293 if 0 != (numerator & (1 << (b + 1))) {
294 t_sum += ts[num_bits - b - 1]
295 }
296 }
297 let mut power = powers[&t_sum].clone();
298 power.pow(r_prod);
299 mu *= &power;
300 }
301 mu
302 } else {
303 let mut mu = x_p.last().unwrap().clone();
304 for _ in 0..half_t {
305 mu *= &mu.clone()
306 }
307 mu
308 });
309 let mut mu: V = mus.last().unwrap().clone();
310 let last_r: V::BigNum = generate_r_value(&x_p[0], &y_p[0], &mu, int_size_bits);
311 assert!(last_r >= Zero::zero());
312 rs.push(last_r.clone());
313 {
314 let mut last_x: V = x_p.last().unwrap().clone();
315 last_x.pow(last_r.clone());
316 last_x *= μ
317 x_p.push(last_x);
318 }
319 mu.pow(last_r);
320 mu *= y_p.last().unwrap();
321 y_p.push(mu);
322 curr_t >>= 1;
323 if curr_t & 1 != 0 {
324 curr_t += 1;
325 y_p.last_mut().unwrap().square();
326 }
327 round_index += 1
328 }
329 if cfg!(debug_assertions) {
330 let mut last_y = y_p.last().unwrap().clone();
331 let mut last_x = x_p.last().unwrap().clone();
332 let one: V::BigNum = 1u64.into();
333 last_y.pow(one.clone());
334 assert_eq!(last_y, y_p.last().unwrap().clone());
335 last_x.pow(one << final_t as usize);
336 }
337 mus
338}
339
340pub fn verify_proof<T, U, V>(
341 x_initial: &V,
342 y_initial: &V,
343 proof: T,
344 t: Iterations,
345 delta: usize,
346 generate_r_value: &U,
347 int_size_bits: usize,
348) -> Result<(), ()>
349where
350 T: IntoIterator<Item = V>,
351 U: Fn(&V, &V, &V, usize) -> V::BigNum,
352 V: ClassGroup,
353 for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
354 for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
355{
356 let mut one: V::BigNum = One::one();
357 let (mut x, mut y): (V, V) = (x_initial.clone(), y_initial.clone());
358 let final_t = calculate_final_t(t, delta);
359 let mut curr_t = t.0;
360 for mut mu in proof {
361 assert!(
362 curr_t & 1 == 0,
363 "Cannot have an odd number of iterations remaining"
364 );
365 let r = generate_r_value(x_initial, y_initial, &mu, int_size_bits);
366 x.pow(r.clone());
367 x *= μ
368 mu.pow(r);
369 y *= μ
370
371 curr_t >>= 1;
372 if curr_t & 1 != 0 {
373 curr_t += 1;
374 y.square();
375 }
376 }
377 one <<= final_t as _;
378 x.pow(one);
379 if x == y {
380 Ok(())
381 } else {
382 Err(())
383 }
384}
385
/// Pietrzak VDF instance.
///
/// `int_size_bits` is passed as the bit length to `create_discriminant` when
/// solving and verifying.
#[derive(Clone, Debug)]
pub struct PietrzakVDF {
    int_size_bits: u16,
}
390use super::InvalidIterations as Bad;
391
392#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Debug)]
393pub struct PietrzakVDFParams(pub u16);
394impl super::VDFParams for PietrzakVDFParams {
395 type VDF = PietrzakVDF;
396 fn new(self) -> Self::VDF {
397 PietrzakVDF {
398 int_size_bits: self.0,
399 }
400 }
401}
402
403impl super::VDF for PietrzakVDF {
404 fn check_difficulty(&self, difficulty: u64) -> Result<(), Bad> {
405 Iterations::new(difficulty)
406 .map_err(|x| Bad(format!("{}", x)))
407 .map(drop)
408 }
409 fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
410 Ok(create_proof_of_time_pietrzak::<GmpClassGroup>(
411 challenge,
412 Iterations::new(difficulty).map_err(|x| Bad(format!("{}", x)))?,
413 self.int_size_bits,
414 ))
415 }
416
417 fn verify(
418 &self,
419 challenge: &[u8],
420 difficulty: u64,
421 alleged_solution: &[u8],
422 ) -> Result<(), super::InvalidProof> {
423 check_proof_of_time_pietrzak::<GmpClassGroup>(
424 challenge,
425 alleged_solution,
426 difficulty,
427 self.int_size_bits,
428 )
429 }
430}
431
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn check_approximate_i() {
        assert_eq!(approximate_i(Iterations(534)), 2);
        assert_eq!(approximate_i(Iterations(134)), 1);
        assert_eq!(approximate_i(Iterations(1024)), 2);
    }

    #[test]
    fn check_sum_combinations() {
        assert_eq!(sum_combinations(&[1u64, 2, 4]), vec![1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(sum_combinations(&[33u64]), vec![33]);
        let empty: [u64; 0] = [];
        assert!(sum_combinations(&empty).is_empty());
    }

    #[test]
    fn check_cache_indices() {
        assert_eq!(cache_indices_for_count(Iterations(66))[..], [33, 66]);
        assert_eq!(
            cache_indices_for_count(Iterations(534))[..],
            [134, 267, 401, 534]
        );
    }

    #[test]
    fn check_calculate_final_t() {
        assert_eq!(calculate_final_t(Iterations(1024), 8), 128);
        assert_eq!(calculate_final_t(Iterations(1000), 8), 126);
        assert_eq!(calculate_final_t(Iterations(100), 8), 100);
    }

    #[test]
    fn check_iterations_validation() {
        assert_eq!(Iterations::new(66u64), Ok(Iterations(66)));
        // Oddness is checked before the lower bound.
        assert_eq!(
            Iterations::new(65u64),
            Err(InvalidIterations::OddNumber(65))
        );
        assert_eq!(
            Iterations::new(64u64),
            Err(InvalidIterations::LessThan66(64))
        );
        assert_eq!(
            Iterations::new(0u64),
            Err(InvalidIterations::LessThan66(0))
        );
        assert_eq!("66".parse::<Iterations>(), Ok(Iterations(66)));
        assert!("65".parse::<Iterations>().is_err());
        assert!("not a number".parse::<Iterations>().is_err());
        assert_eq!(
            InvalidIterations::OddNumber(3).to_string(),
            "Pietrzak iterations must be an even number, not 3"
        );
    }

    #[test]
    fn check_assumptions_about_stdlib() {
        assert_eq!(62 - u64::leading_zeros(1024u64), 9);
        let mut q: Vec<_> = (1..4).step_by(2).collect();
        assert_eq!(q[..], [1, 3]);
        q = (1..3).step_by(2).collect();
        assert_eq!(q[..], [1]);
        q = (1..2).step_by(2).collect();
        assert_eq!(q[..], [1]);
    }
}