Skip to main content

vdf_rs/
proof_pietrzak.rs

1// Copyright 2018 Chia Network Inc and POA Networks Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use super::proof_of_time::{deserialize_proof, iterate_squarings, serialize};
15use vdf_classgroup::{gmp_classgroup::GmpClassGroup, BigNumExt, ClassGroup};
16use num_traits::{One, Zero};
17use std::{fmt, num::ParseIntError, ops::Index, str::FromStr, u64, usize};
18
/// An iteration count accepted by `Iterations::new`: even and at least 66.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Iterations(u64);

/// Reasons `Iterations::new` rejects an iteration count.
///
/// Variant order is significant for the derived `Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum InvalidIterations {
    /// The count was odd; carries the rejected value.
    OddNumber(u64),
    /// The count was even but below 66; carries the rejected value.
    LessThan66(u64),
}
27
28#[derive(PartialEq, Eq, Clone, Debug)]
29pub struct ParseIterationsError {
30    kind: Result<InvalidIterations, ParseIntError>,
31}
32
33impl From<InvalidIterations> for ParseIterationsError {
34    fn from(t: InvalidIterations) -> Self {
35        Self { kind: Ok(t) }
36    }
37}
38
39impl From<ParseIntError> for ParseIterationsError {
40    fn from(t: ParseIntError) -> Self {
41        Self { kind: Err(t) }
42    }
43}
44
45impl fmt::Display for InvalidIterations {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match *self {
48            InvalidIterations::OddNumber(s) => {
49                write!(f, "Pietrzak iterations must be an even number, not {}", s)
50            }
51            InvalidIterations::LessThan66(s) => write!(
52                f,
53                "Pietrzak proof-of-time must run for at least 66 iterations, not {}",
54                s
55            ),
56        }
57    }
58}
59
60impl From<Iterations> for u64 {
61    fn from(t: Iterations) -> u64 {
62        t.0
63    }
64}
65
66impl fmt::Display for ParseIterationsError {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match self.kind {
69            Ok(ref q) => <InvalidIterations as fmt::Display>::fmt(q, f),
70            Err(ref q) => <ParseIntError as fmt::Display>::fmt(q, f),
71        }
72    }
73}
74
75impl FromStr for Iterations {
76    type Err = ParseIterationsError;
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        Self::new(s.parse::<u64>().map_err(ParseIterationsError::from)?)
79            .map_err(ParseIterationsError::from)
80    }
81}
82
83impl Iterations {
84    pub fn new<T: Into<u64>>(iterations: T) -> Result<Iterations, InvalidIterations> {
85        let iterations = iterations.into();
86        if iterations & 1 != 0 {
87            Err(InvalidIterations::OddNumber(iterations))
88        } else if iterations < 66 {
89            Err(InvalidIterations::LessThan66(iterations))
90        } else {
91            Ok(Iterations(iterations))
92        }
93    }
94}
95
96/// Selects a reasonable choice of cache size.
97fn approximate_i(t: Iterations) -> u64 {
98    let x: f64 = (((t.0 >> 1) as f64) / 8.) * 2.0f64.ln();
99    let w = x.ln() - x.ln().ln() + 0.25;
100    (w / (2. * 2.0f64.ln())).round() as _
101}
102
/// Returns the sums of every non-empty subset of `numbers`.
///
/// Sums are listed in binary-counting order: element `i - 1` of the result is
/// the sum of the `numbers[k]` whose bit `k` is set in `i`. The empty subset
/// (sum 0) is omitted, so an empty input yields an empty vector.
fn sum_combinations<'a, T: IntoIterator<Item = &'a u64>>(numbers: T) -> Vec<u64> {
    let mut sums: Vec<u64> = vec![0];
    for &n in numbers {
        // Append "with n" versions of every sum seen so far.
        let existing = sums.len();
        for idx in 0..existing {
            let with_n = sums[idx] + n;
            sums.push(with_n);
        }
    }
    // Drop the leading empty-subset sum.
    sums.split_off(1)
}
115
116fn cache_indices_for_count(t: Iterations) -> Vec<u64> {
117    let i: u64 = approximate_i(t);
118    let mut curr_t = t.0;
119    let mut intermediate_ts = vec![];
120    for _ in 0..i {
121        curr_t >>= 1;
122        intermediate_ts.push(curr_t);
123        if curr_t & 1 != 0 {
124            curr_t += 1
125        }
126    }
127    let mut cache_indices = sum_combinations(&intermediate_ts);
128    cache_indices.sort();
129    cache_indices.push(t.0);
130    cache_indices
131}
132
133fn generate_r_value<T>(x: &T, y: &T, sqrt_mu: &T, int_size_bits: usize) -> T::BigNum
134where
135    T: ClassGroup,
136    for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
137    for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
138{
139    use sha2::{Digest, Sha256};
140
141    let size = (int_size_bits + 16) >> 4;
142    let mut v = Vec::with_capacity(size * 2);
143    for _ in 0..size * 2 {
144        v.push(0)
145    }
146    let mut hasher = Sha256::new();
147    for i in &[&x, &y, &sqrt_mu] {
148        i.serialize(&mut v).expect(super::INCORRECT_BUFFER_SIZE);
149        hasher.update(&v);
150    }
151    let res = hasher.finalize();
152    T::unsigned_deserialize_bignum(&res[..16])
153}
154
155fn create_proof_of_time_pietrzak<T>(
156    challenge: &[u8],
157    iterations: Iterations,
158    int_size_bits: u16,
159) -> Vec<u8>
160where
161    T: ClassGroup,
162    <T as ClassGroup>::BigNum: BigNumExt,
163    for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
164    for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
165{
166    let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
167    let x = T::from_ab_discriminant(2.into(), 1.into(), discriminant);
168
169    let delta = 8;
170    let powers_to_calculate = cache_indices_for_count(iterations);
171    let powers = iterate_squarings(x.clone(), powers_to_calculate.iter().cloned());
172    let proof: Vec<T> = generate_proof(
173        x,
174        iterations,
175        delta,
176        &powers,
177        &generate_r_value,
178        usize::from(int_size_bits),
179    );
180    serialize(
181        &proof,
182        &powers[&iterations.into()],
183        usize::from(int_size_bits),
184    )
185}
186
187pub fn check_proof_of_time_pietrzak<T>(
188    challenge: &[u8],
189    proof_blob: &[u8],
190    iterations: u64,
191    length_in_bits: u16,
192) -> Result<(), super::InvalidProof>
193where
194    T: ClassGroup,
195    T::BigNum: BigNumExt,
196    for<'a, 'b> &'a T: std::ops::Mul<&'b T, Output = T>,
197    for<'a, 'b> &'a T::BigNum: std::ops::Mul<&'b T::BigNum, Output = T::BigNum>,
198{
199    let discriminant = super::create_discriminant::create_discriminant(&challenge, length_in_bits);
200    let x = T::from_ab_discriminant(2.into(), 1.into(), discriminant);
201    let iterations = Iterations::new(iterations).map_err(|_| super::InvalidProof)?;
202    if usize::MAX - 16 < length_in_bits.into() {
203        // Proof way too long.
204        return Err(super::InvalidProof);
205    }
206    let length: usize = (usize::from(length_in_bits) + 16usize) >> 4;
207    if proof_blob.len() < 2 * length {
208        // Invalid length of proof
209        return Err(super::InvalidProof);
210    }
211    let result_bytes = &proof_blob[..length * 2];
212    let proof_bytes = &proof_blob[length * 2..];
213    let discriminant = x.discriminant().clone();
214    let proof =
215        deserialize_proof(proof_bytes, &discriminant, length).map_err(|()| super::InvalidProof)?;
216    let y = T::from_bytes(result_bytes, discriminant);
217    verify_proof(
218        &x,
219        &y,
220        proof,
221        iterations,
222        8,
223        &generate_r_value,
224        length_in_bits.into(),
225    )
226    .map_err(|()| super::InvalidProof)
227}
228
229fn calculate_final_t(t: Iterations, delta: usize) -> u64 {
230    let mut curr_t = t.0;
231    let mut ts = vec![];
232    while curr_t != 2 {
233        ts.push(curr_t);
234        curr_t >>= 1;
235        if curr_t & 1 == 1 {
236            curr_t += 1
237        }
238    }
239    ts.push(2);
240    ts.push(1);
241    assert!(ts.len() >= delta);
242    ts[ts.len() - delta]
243}
244
245pub fn generate_proof<T, U, V>(
246    x: V,
247    iterations: Iterations,
248    delta: usize,
249    powers: &T,
250    generate_r_value: &U,
251    int_size_bits: usize,
252) -> Vec<V>
253where
254    T: for<'a> Index<&'a u64, Output = V>,
255    U: Fn(&V, &V, &V, usize) -> V::BigNum,
256    V: ClassGroup,
257    for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
258    for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
259{
260    let identity = x.identity();
261    let i = approximate_i(iterations);
262    let mut mus = vec![];
263    let mut rs: Vec<V::BigNum> = vec![];
264    let mut x_p = vec![x];
265    let mut curr_t = iterations.0;
266
267    let mut y_p = vec![powers[&curr_t].clone()];
268
269    let mut ts = vec![];
270
271    let final_t = calculate_final_t(iterations, delta);
272
273    let mut round_index = 0;
274    while curr_t != final_t {
275        assert_eq!(curr_t & 1, 0);
276        let half_t = curr_t >> 1;
277        ts.push(half_t);
278        assert!(round_index < 63);
279        let denominator: u64 = 1 << (round_index + 1);
280
281        mus.push(if round_index < i {
282            let mut mu = identity.clone();
283            for numerator in (1..denominator).step_by(2) {
284                let num_bits = 62 - denominator.leading_zeros() as usize;
285                let mut r_prod: V::BigNum = One::one();
286                for b in (0..num_bits).rev() {
287                    if 0 == (numerator & (1 << (b + 1))) {
288                        r_prod *= &rs[num_bits - b - 1]
289                    }
290                }
291                let mut t_sum = half_t;
292                for b in 0..num_bits {
293                    if 0 != (numerator & (1 << (b + 1))) {
294                        t_sum += ts[num_bits - b - 1]
295                    }
296                }
297                let mut power = powers[&t_sum].clone();
298                power.pow(r_prod);
299                mu *= &power;
300            }
301            mu
302        } else {
303            let mut mu = x_p.last().unwrap().clone();
304            for _ in 0..half_t {
305                mu *= &mu.clone()
306            }
307            mu
308        });
309        let mut mu: V = mus.last().unwrap().clone();
310        let last_r: V::BigNum = generate_r_value(&x_p[0], &y_p[0], &mu, int_size_bits);
311        assert!(last_r >= Zero::zero());
312        rs.push(last_r.clone());
313        {
314            let mut last_x: V = x_p.last().unwrap().clone();
315            last_x.pow(last_r.clone());
316            last_x *= &mu;
317            x_p.push(last_x);
318        }
319        mu.pow(last_r);
320        mu *= y_p.last().unwrap();
321        y_p.push(mu);
322        curr_t >>= 1;
323        if curr_t & 1 != 0 {
324            curr_t += 1;
325            y_p.last_mut().unwrap().square();
326        }
327        round_index += 1
328    }
329    if cfg!(debug_assertions) {
330        let mut last_y = y_p.last().unwrap().clone();
331        let mut last_x = x_p.last().unwrap().clone();
332        let one: V::BigNum = 1u64.into();
333        last_y.pow(one.clone());
334        assert_eq!(last_y, y_p.last().unwrap().clone());
335        last_x.pow(one << final_t as usize);
336    }
337    mus
338}
339
340pub fn verify_proof<T, U, V>(
341    x_initial: &V,
342    y_initial: &V,
343    proof: T,
344    t: Iterations,
345    delta: usize,
346    generate_r_value: &U,
347    int_size_bits: usize,
348) -> Result<(), ()>
349where
350    T: IntoIterator<Item = V>,
351    U: Fn(&V, &V, &V, usize) -> V::BigNum,
352    V: ClassGroup,
353    for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
354    for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
355{
356    let mut one: V::BigNum = One::one();
357    let (mut x, mut y): (V, V) = (x_initial.clone(), y_initial.clone());
358    let final_t = calculate_final_t(t, delta);
359    let mut curr_t = t.0;
360    for mut mu in proof {
361        assert!(
362            curr_t & 1 == 0,
363            "Cannot have an odd number of iterations remaining"
364        );
365        let r = generate_r_value(x_initial, y_initial, &mu, int_size_bits);
366        x.pow(r.clone());
367        x *= &mu;
368        mu.pow(r);
369        y *= &mu;
370
371        curr_t >>= 1;
372        if curr_t & 1 != 0 {
373            curr_t += 1;
374            y.square();
375        }
376    }
377    one <<= final_t as _;
378    x.pow(one);
379    if x == y {
380        Ok(())
381    } else {
382        Err(())
383    }
384}
385
/// A Pietrzak VDF instance backed by `GmpClassGroup`.
///
/// `int_size_bits` is the bit length passed to discriminant generation and
/// proof (de)serialization when solving and verifying.
#[derive(Clone, Debug)]
pub struct PietrzakVDF {
    int_size_bits: u16,
}
390use super::InvalidIterations as Bad;
391
392#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Debug)]
393pub struct PietrzakVDFParams(pub u16);
394impl super::VDFParams for PietrzakVDFParams {
395    type VDF = PietrzakVDF;
396    fn new(self) -> Self::VDF {
397        PietrzakVDF {
398            int_size_bits: self.0,
399        }
400    }
401}
402
403impl super::VDF for PietrzakVDF {
404    fn check_difficulty(&self, difficulty: u64) -> Result<(), Bad> {
405        Iterations::new(difficulty)
406            .map_err(|x| Bad(format!("{}", x)))
407            .map(drop)
408    }
409    fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
410        Ok(create_proof_of_time_pietrzak::<GmpClassGroup>(
411            challenge,
412            Iterations::new(difficulty).map_err(|x| Bad(format!("{}", x)))?,
413            self.int_size_bits,
414        ))
415    }
416
417    fn verify(
418        &self,
419        challenge: &[u8],
420        difficulty: u64,
421        alleged_solution: &[u8],
422    ) -> Result<(), super::InvalidProof> {
423        check_proof_of_time_pietrzak::<GmpClassGroup>(
424            challenge,
425            alleged_solution,
426            difficulty,
427            self.int_size_bits,
428        )
429    }
430}
431
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn check_approximate_i() {
        assert_eq!(approximate_i(Iterations(534)), 2);
        assert_eq!(approximate_i(Iterations(134)), 1);
        assert_eq!(approximate_i(Iterations(1024)), 2);
    }

    #[test]
    fn check_cache_indices() {
        assert_eq!(cache_indices_for_count(Iterations(66))[..], [33, 66]);
        assert_eq!(
            cache_indices_for_count(Iterations(534))[..],
            [134, 267, 401, 534]
        );
    }

    #[test]
    fn check_sum_combinations() {
        assert_eq!(sum_combinations(&[1u64, 2, 4]), vec![1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(sum_combinations(&[267u64, 134]), vec![267, 134, 401]);
        let empty: Vec<u64> = Vec::new();
        assert!(sum_combinations(&empty).is_empty());
    }

    #[test]
    fn check_calculate_final_t() {
        assert_eq!(calculate_final_t(Iterations(1024), 8), 128);
        assert_eq!(calculate_final_t(Iterations(1000), 8), 126);
        assert_eq!(calculate_final_t(Iterations(100), 8), 100);
        // Minimum valid count: the schedule has exactly 8 entries, so no rounds run.
        assert_eq!(calculate_final_t(Iterations(66), 8), 66);
    }

    #[test]
    fn check_iterations_validation() {
        assert_eq!(Iterations::new(66u64), Ok(Iterations(66)));
        assert_eq!(Iterations::new(65u64), Err(InvalidIterations::OddNumber(65)));
        assert_eq!(Iterations::new(64u64), Err(InvalidIterations::LessThan66(64)));
        assert_eq!(u64::from(Iterations::new(1000u32).unwrap()), 1000);
    }

    #[test]
    fn check_iterations_from_str() {
        assert_eq!("100".parse::<Iterations>(), Ok(Iterations(100)));
        assert_eq!(
            "67".parse::<Iterations>(),
            Err(ParseIterationsError::from(InvalidIterations::OddNumber(67)))
        );
        // Non-numeric input is reported as a ParseIntError.
        assert!("abc".parse::<Iterations>().unwrap_err().kind.is_err());
    }

    #[test]
    fn check_assumptions_about_stdlib() {
        assert_eq!(62 - u64::leading_zeros(1024u64), 9);
        let mut q: Vec<_> = (1..4).step_by(2).collect();
        assert_eq!(q[..], [1, 3]);
        q = (1..3).step_by(2).collect();
        assert_eq!(q[..], [1]);
        q = (1..2).step_by(2).collect();
        assert_eq!(q[..], [1]);
    }
}