vdf/
proof_wesolowski.rs

1// Copyright 2018 POA Networks Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::proof_of_time::{iterate_squarings, serialize};
16use classgroup::{gmp_classgroup::GmpClassGroup, BigNum, BigNumExt, ClassGroup};
17use sha2::{digest::FixedOutput, Digest, Sha256};
18use std::{cmp::Eq, collections::HashMap, hash::Hash, mem, u64, usize};
19
/// A Wesolowski verifiable delay function instance.
///
/// The only state is the bit size that `solve` and `verify` forward to the
/// proof creation and proof checking routines of this module.
#[derive(Clone, Debug)]
pub struct WesolowskiVDF {
    /// Bit size handed to discriminant creation and proof (de)serialization.
    int_size_bits: u16,
}
24use super::InvalidIterations as Bad;
25
/// Parameters of a Wesolowski VDF: the bit size used when the VDF is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct WesolowskiVDFParams(pub u16);
28impl super::VDFParams for WesolowskiVDFParams {
29    type VDF = WesolowskiVDF;
30    fn new(self) -> Self::VDF {
31        WesolowskiVDF {
32            int_size_bits: self.0,
33        }
34    }
35}
36
37impl super::VDF for WesolowskiVDF {
38    fn check_difficulty(&self, _difficulty: u64) -> Result<(), Bad> {
39        Ok(())
40    }
41    fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
42        if difficulty > usize::MAX as u64 {
43            Err(Bad("Cannot have more that usize::MAX iterations".to_owned()))
44        } else {
45            Ok(create_proof_of_time_wesolowski::<
46                <GmpClassGroup as ClassGroup>::BigNum,
47                GmpClassGroup,
48            >(
49                challenge, difficulty as usize, self.int_size_bits
50            ))
51        }
52    }
53
54    fn verify(
55        &self,
56        challenge: &[u8],
57        difficulty: u64,
58        alleged_solution: &[u8],
59    ) -> Result<(), super::InvalidProof> {
60        check_proof_of_time_wesolowski::<<GmpClassGroup as ClassGroup>::BigNum, GmpClassGroup>(
61            challenge,
62            alleged_solution,
63            difficulty,
64            self.int_size_bits,
65        )
66        .map_err(|()| super::InvalidProof)
67    }
68}
/// Picks the `(l, k, w)` parameters for `t` iterations.
///
/// The original Python code describes this as:
///
/// > Create `L` and `k` parameters from papers, based on how many iterations
/// > need to be performed, and how much memory should be used.
///
/// `l` is `1` unless `t` exceeds the hard-coded memory budget of
/// `10_000_000`, `k` is never smaller than `1`, and all three values are
/// produced by casting `f64` results to the returned integer types.
pub fn approximate_parameters(t: f64) -> (usize, u8, u64) {
    let ln_two = (2.0f64).ln();
    let log_memory = (10_000_000.0f64).log2();

    // Only when `t` is above the memory budget does `l` grow beyond one.
    let over_budget = t.log2() - log_memory > 0.;
    let l: f64 = if over_budget {
        2.0f64.powf(log_memory - 20.).ceil()
    } else {
        1.
    };

    let intermediate = t * ln_two / (2.0 * l);
    let ln_intermediate = intermediate.ln();
    // `f64::max` ignores a NaN operand, so small `t` values fall back to 1.
    let k: f64 = (ln_intermediate - ln_intermediate.ln() + 0.25)
        .round()
        .max(1.);

    let denominator = t / k + l * (2.0f64).powf(k + 1.0);
    let w: f64 = (t / denominator - 2.0).floor();

    (l as usize, k as u8, w as u64)
}
90
/// Encodes `q` as eight bytes in big-endian order (most significant byte
/// first).
fn u64_to_bytes(q: u64) -> [u8; 8] {
    // The standard library conversion replaces both the manual shifts and the
    // dead `transmute` branch the function used to carry.
    q.to_be_bytes()
}
108
109/// Quote:
110///
111/// > Creates a random prime based on input s.
112fn hash_prime<T: BigNum>(seed: &[&[u8]]) -> T {
113    let mut j = 0u64;
114    loop {
115        let mut hasher = Sha256::new();
116        hasher.input(b"prime");
117        hasher.input(u64_to_bytes(j));
118        for i in seed {
119            hasher.input(i);
120        }
121        let n = T::from(&hasher.fixed_result()[..16]);
122        if n.probab_prime(2) {
123            break n;
124        }
125        j += 1;
126    }
127}
128
129/// Quote:
130///
131/// > Get“s the ith block of `2^T // B`, such that `sum(get_block(i) * 2^(k*i))
132/// > = t^T // B`
133fn get_block<T: BigNumExt>(i: u64, k: u8, t: u64, b: &T) -> T {
134    let mut res = T::from(0);
135    let two = T::from(2);
136    res.mod_powm(&two, &T::from(t - u64::from(k) * (i + 1)), b);
137    res *= &((two >> 1) << (k as usize));
138    res / b
139}
140
141fn eval_optimized<T, U: BigNumExt, L: ClassGroup<BigNum = U> + Eq + Hash>(
142    h: &L,
143    b: &U,
144    t: usize,
145    k: u8,
146    l: usize,
147    powers: &T,
148) -> L
149where
150    T: for<'a> std::ops::Index<&'a u64, Output = L>,
151{
152    assert!(k > 0, "k cannot be zero");
153    assert!(l > 0, "l cannot be zero");
154    let kl = (k as usize)
155        .checked_mul(l)
156        .expect("computing k*l overflowed a u64");
157    assert!(kl <= u64::MAX as _);
158    assert!((kl as u64) < (1u64 << 53), "k*l overflowed an f64");
159    assert!((t as u64) < (1u64 << 53), "t overflows an f64");
160    assert!(
161        k < (mem::size_of::<usize>() << 3) as u8,
162        "k must be less than the number of bits in a usize"
163    );
164    let k1 = k >> 1;
165    let k0 = k - k1;
166    let mut x = h.identity();
167    let identity = h.identity();
168    let k_exp = 1usize << k;
169    let k0_exp = 1usize << k0;
170    let k1_exp = 1usize << k1;
171    for j in (0..l).rev() {
172        x.pow(U::from(k_exp as u64));
173        let mut ys: HashMap<U, L> = HashMap::new();
174        for b in 0..1usize << k {
175            ys.entry(U::from(b as u64))
176                .or_insert_with(|| identity.clone());
177        }
178        let end_of_loop = ((t as f64) / kl as f64).ceil() as usize;
179        assert!(end_of_loop == 0 || (end_of_loop as u64 - 1).checked_mul(l as u64).is_some());
180        for i in 0..end_of_loop {
181            if t < k as usize * (i * l + j + 1) {
182                continue;
183            }
184            let b = get_block((i as u64) * (l as u64), k, t as _, b);
185            *ys.get_mut(&b).unwrap() *= &powers[&((i * kl) as _)];
186        }
187
188        for b1 in 0..k1_exp {
189            let mut z = identity.clone();
190            for b0 in 0..k0_exp {
191                z *= &ys[&U::from((b1 * k0_exp + b0) as u64)]
192            }
193            z.pow(U::from((b1 as u64) * (k0_exp as u64)));
194            x *= &z;
195        }
196
197        for b0 in 0..k0_exp {
198            let mut z = identity.clone();
199            for b1 in 0..k1_exp {
200                z *= &ys[&U::from((b1 * k0_exp + b0) as u64)];
201            }
202            z.pow(U::from(b0 as u64));
203            x *= &z;
204        }
205    }
206    x
207}
208
209pub fn generate_proof<U, T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
210    x: &V,
211    iterations: u64,
212    k: u8,
213    l: usize,
214    powers: &U,
215    int_size_bits: usize,
216) -> V
217where
218    U: for<'a> std::ops::Index<&'a u64, Output = V>,
219{
220    let element_len = 2 * ((int_size_bits + 16) >> 4);
221    let mut x_buf = vec![0; element_len];
222    x.serialize(&mut x_buf[..])
223        .expect(super::INCORRECT_BUFFER_SIZE);
224    let mut y_buf = vec![0; element_len];
225    powers[&iterations]
226        .serialize(&mut y_buf[..])
227        .expect(super::INCORRECT_BUFFER_SIZE);
228    let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
229    eval_optimized(&x, &b, iterations as _, k, l, powers)
230}
231
232/// Verify a proof, according to the Wesolowski paper.
233pub fn verify_proof<T: BigNum, V: ClassGroup<BigNum = T>>(
234    mut x: V,
235    y: &V,
236    mut proof: V,
237    t: u64,
238    int_size_bits: usize,
239) -> Result<(), ()> {
240    let element_len = 2 * ((int_size_bits + 16) >> 4);
241    let mut x_buf = vec![0; element_len];
242    x.serialize(&mut x_buf[..])
243        .expect(super::INCORRECT_BUFFER_SIZE);
244    let mut y_buf = vec![0; element_len];
245    y.serialize(&mut y_buf[..])
246        .expect(super::INCORRECT_BUFFER_SIZE);
247    let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
248    let mut r = T::from(0);
249    r.mod_powm(&T::from(2u64), &T::from(t), &b);
250    proof.pow(b);
251    x.pow(r);
252    proof *= &x;
253    if &proof == y {
254        Ok(())
255    } else {
256        Err(())
257    }
258}
259
260pub fn create_proof_of_time_wesolowski<T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
261    challenge: &[u8],
262    iterations: usize,
263    int_size_bits: u16,
264) -> Vec<u8>
265where
266    for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
267    for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
268{
269    let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
270    let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant);
271    assert!((iterations as u128) < (1u128 << 53));
272    let (l, k, _) = approximate_parameters(iterations as f64);
273    let q = l.checked_mul(k as _).expect("bug");
274    let powers = iterate_squarings(
275        x.clone(),
276        (0..=iterations / q + 1)
277            .map(|i| i * q)
278            .chain(Some(iterations))
279            .map(|x| x as _),
280    );
281    let proof = generate_proof(&x, iterations as _, k, l, &powers, int_size_bits.into());
282    serialize(&[proof], &powers[&(iterations as _)], int_size_bits.into())
283}
284
285pub fn check_proof_of_time_wesolowski<T: BigNum, V: ClassGroup<BigNum = T>>(
286    challenge: &[u8],
287    proof_blob: &[u8],
288    iterations: u64,
289    int_size_bits: u16,
290) -> Result<(), ()>
291where
292    T: BigNumExt,
293{
294    let discriminant: T = super::create_discriminant::create_discriminant(challenge, int_size_bits);
295    let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant.clone());
296    if (usize::MAX - 16) < int_size_bits.into() {
297        return Err(());
298    }
299    let int_size = (usize::from(int_size_bits) + 16) >> 4;
300    if int_size * 4 != proof_blob.len() {
301        return Err(());
302    }
303    let (result_bytes, proof_bytes) = proof_blob.split_at(2 * int_size);
304    let proof = ClassGroup::from_bytes(proof_bytes, discriminant.clone());
305    let y = ClassGroup::from_bytes(result_bytes, discriminant);
306
307    verify_proof(x, &y, proof, iterations, int_size_bits.into())
308}