Skip to main content

vdf_rs/
proof_wesolowski.rs

1// Copyright 2018 POA Networks Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::proof_of_time::{iterate_squarings, serialize};
16use vdf_classgroup::{gmp_classgroup::GmpClassGroup, BigNum, BigNumExt, ClassGroup};
17use sha2::{Digest, Sha256};
18use std::{cmp::Eq, collections::HashMap, hash::Hash, mem, u64, usize};
19
/// Wesolowski verifiable delay function over a class group.
///
/// The only configuration is the bit size used when deriving the group
/// discriminant from a challenge.
#[derive(Debug, Clone)]
pub struct WesolowskiVDF {
    /// Bit length handed to `create_discriminant`; it also fixes the size of
    /// each serialized group element (`2 * ((int_size_bits + 16) >> 4)` bytes).
    int_size_bits: u16,
}
24use super::InvalidIterations as Bad;
25
26#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Debug)]
27pub struct WesolowskiVDFParams(pub u16);
28impl super::VDFParams for WesolowskiVDFParams {
29    type VDF = WesolowskiVDF;
30    fn new(self) -> Self::VDF {
31        WesolowskiVDF {
32            int_size_bits: self.0,
33        }
34    }
35}
36
37impl super::VDF for WesolowskiVDF {
38    fn check_difficulty(&self, _difficulty: u64) -> Result<(), Bad> {
39        Ok(())
40    }
41    fn solve(&self, challenge: &[u8], difficulty: u64) -> Result<Vec<u8>, Bad> {
42        if difficulty > usize::MAX as u64 {
43            Err(Bad("Cannot have more that usize::MAX iterations".to_owned()))
44        } else {
45            Ok(create_proof_of_time_wesolowski::<
46                <GmpClassGroup as ClassGroup>::BigNum,
47                GmpClassGroup,
48            >(
49                challenge, difficulty as usize, self.int_size_bits
50            ))
51        }
52    }
53
54    fn verify(
55        &self,
56        challenge: &[u8],
57        difficulty: u64,
58        alleged_solution: &[u8],
59    ) -> Result<(), super::InvalidProof> {
60        check_proof_of_time_wesolowski::<<GmpClassGroup as ClassGroup>::BigNum, GmpClassGroup>(
61            challenge,
62            alleged_solution,
63            difficulty,
64            self.int_size_bits,
65        )
66        .map_err(|()| super::InvalidProof)
67    }
68}
/// Choose the `(L, k, w)` proof parameters for `t` squarings.
///
/// To quote the original Python code:
///
/// > Create `L` and `k` parameters from papers, based on how many iterations
/// > need to be performed, and how much memory should be used.
///
/// * `L` is `1`, unless `log2(t) > log2(10_000_000)`. In that case it is
///   `ceil(2^(log2(10_000_000) - 20))`, which is 10.
/// * `k` approximates the optimum via `ln x - ln ln x + 0.25` with
///   `x = t * ln 2 / (2 * L)`, rounded and clamped to at least 1. For tiny
///   `t`, where `ln ln x` is NaN, `f64::max` discards the NaN, giving `k = 1`.
/// * `w` is `floor(t / (t / k + L * 2^(k + 1)) - 2)`. The final `as` cast
///   saturates, so a negative value becomes `0`.
pub fn approximate_parameters(t: f64) -> (usize, u8, u64) {
    let memory_log2 = 10_000_000.0f64.log2();
    let t_log2 = t.log2();
    let l = if t_log2 > memory_log2 {
        2.0f64.powf(memory_log2 - 20.).ceil()
    } else {
        1.
    };

    let x = t * 2.0f64.ln() / (2.0 * l);
    let ln_x = x.ln();
    let k = (ln_x - ln_x.ln() + 0.25).round().max(1.);

    let proof_cost = t / k + l * 2.0f64.powf(k + 1.0);
    let w = (t / proof_cost - 2.0).floor();
    (l as usize, k as u8, w as u64)
}
90
/// Encode `q` as eight big-endian bytes.
fn u64_to_bytes(q: u64) -> [u8; 8] {
    u64::to_be_bytes(q)
}
95
96/// Quote:
97///
98/// > Creates a random prime based on input s.
99fn hash_prime<T: BigNum>(seed: &[&[u8]]) -> T {
100    let mut j = 0u64;
101    loop {
102        let mut hasher = Sha256::new();
103        hasher.update(b"prime");
104        hasher.update(u64_to_bytes(j));
105        for i in seed {
106            hasher.update(i);
107        }
108        let n = T::from(&hasher.finalize()[..16]);
109        if n.probab_prime(2) {
110            break n;
111        }
112        j += 1;
113    }
114}
115
116/// Quote:
117///
118/// > Get“s the ith block of `2^T // B`, such that `sum(get_block(i) * 2^(k*i))
119/// > = t^T // B`
120fn get_block<T: BigNumExt>(i: u64, k: u8, t: u64, b: &T) -> T {
121    let mut res = T::from(0);
122    let two = T::from(2);
123    res.mod_powm(&two, &T::from(t - u64::from(k) * (i + 1)), b);
124    res *= &((two >> 1) << (k as usize));
125    res / b
126}
127
128fn eval_optimized<T, U: BigNumExt, L: ClassGroup<BigNum = U> + Eq + Hash>(
129    h: &L,
130    b: &U,
131    t: usize,
132    k: u8,
133    l: usize,
134    powers: &T,
135) -> L
136where
137    T: for<'a> std::ops::Index<&'a u64, Output = L>,
138{
139    assert!(k > 0, "k cannot be zero");
140    assert!(l > 0, "l cannot be zero");
141    let kl = (k as usize)
142        .checked_mul(l)
143        .expect("computing k*l overflowed a u64");
144    assert!(kl <= u64::MAX as _);
145    assert!((kl as u64) < (1u64 << 53), "k*l overflowed an f64");
146    assert!((t as u64) < (1u64 << 53), "t overflows an f64");
147    assert!(
148        k < (mem::size_of::<usize>() << 3) as u8,
149        "k must be less than the number of bits in a usize"
150    );
151    let k1 = k >> 1;
152    let k0 = k - k1;
153    let mut x = h.identity();
154    let identity = h.identity();
155    let k_exp = 1usize << k;
156    let k0_exp = 1usize << k0;
157    let k1_exp = 1usize << k1;
158    for j in (0..l).rev() {
159        x.pow(U::from(k_exp as u64));
160        let mut ys: HashMap<U, L> = HashMap::new();
161        for b in 0..1usize << k {
162            ys.entry(U::from(b as u64))
163                .or_insert_with(|| identity.clone());
164        }
165        let end_of_loop = ((t as f64) / kl as f64).ceil() as usize;
166        assert!(end_of_loop == 0 || (end_of_loop as u64 - 1).checked_mul(l as u64).is_some());
167        for i in 0..end_of_loop {
168            if t < k as usize * (i * l + j + 1) {
169                continue;
170            }
171            let b = get_block((i as u64) * (l as u64), k, t as _, b);
172            *ys.get_mut(&b).unwrap() *= &powers[&((i * kl) as _)];
173        }
174
175        for b1 in 0..k1_exp {
176            let mut z = identity.clone();
177            for b0 in 0..k0_exp {
178                z *= &ys[&U::from((b1 * k0_exp + b0) as u64)]
179            }
180            z.pow(U::from((b1 as u64) * (k0_exp as u64)));
181            x *= &z;
182        }
183
184        for b0 in 0..k0_exp {
185            let mut z = identity.clone();
186            for b1 in 0..k1_exp {
187                z *= &ys[&U::from((b1 * k0_exp + b0) as u64)];
188            }
189            z.pow(U::from(b0 as u64));
190            x *= &z;
191        }
192    }
193    x
194}
195
196pub fn generate_proof<U, T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
197    x: &V,
198    iterations: u64,
199    k: u8,
200    l: usize,
201    powers: &U,
202    int_size_bits: usize,
203) -> V
204where
205    U: for<'a> std::ops::Index<&'a u64, Output = V>,
206{
207    let element_len = 2 * ((int_size_bits + 16) >> 4);
208    let mut x_buf = vec![0; element_len];
209    x.serialize(&mut x_buf[..])
210        .expect(super::INCORRECT_BUFFER_SIZE);
211    let mut y_buf = vec![0; element_len];
212    powers[&iterations]
213        .serialize(&mut y_buf[..])
214        .expect(super::INCORRECT_BUFFER_SIZE);
215    let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
216    eval_optimized(&x, &b, iterations as _, k, l, powers)
217}
218
219/// Verify a proof, according to the Wesolowski paper.
220pub fn verify_proof<T: BigNum, V: ClassGroup<BigNum = T>>(
221    mut x: V,
222    y: &V,
223    mut proof: V,
224    t: u64,
225    int_size_bits: usize,
226) -> Result<(), ()> {
227    let element_len = 2 * ((int_size_bits + 16) >> 4);
228    let mut x_buf = vec![0; element_len];
229    x.serialize(&mut x_buf[..])
230        .expect(super::INCORRECT_BUFFER_SIZE);
231    let mut y_buf = vec![0; element_len];
232    y.serialize(&mut y_buf[..])
233        .expect(super::INCORRECT_BUFFER_SIZE);
234    let b = hash_prime(&[&x_buf[..], &y_buf[..]]);
235    let mut r = T::from(0);
236    r.mod_powm(&T::from(2u64), &T::from(t), &b);
237    proof.pow(b);
238    x.pow(r);
239    proof *= &x;
240    if &proof == y {
241        Ok(())
242    } else {
243        Err(())
244    }
245}
246
247pub fn create_proof_of_time_wesolowski<T: BigNumExt, V: ClassGroup<BigNum = T> + Eq + Hash>(
248    challenge: &[u8],
249    iterations: usize,
250    int_size_bits: u16,
251) -> Vec<u8>
252where
253    for<'a, 'b> &'a V: std::ops::Mul<&'b V, Output = V>,
254    for<'a, 'b> &'a V::BigNum: std::ops::Mul<&'b V::BigNum, Output = V::BigNum>,
255{
256    let discriminant = super::create_discriminant::create_discriminant(&challenge, int_size_bits);
257    let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant);
258    assert!((iterations as u128) < (1u128 << 53));
259    let (l, k, _) = approximate_parameters(iterations as f64);
260    let q = l.checked_mul(k as _).expect("bug");
261    let powers = iterate_squarings(
262        x.clone(),
263        (0..=iterations / q + 1)
264            .map(|i| i * q)
265            .chain(Some(iterations))
266            .map(|x| x as _),
267    );
268    let proof = generate_proof(&x, iterations as _, k, l, &powers, int_size_bits.into());
269    serialize(&[proof], &powers[&(iterations as _)], int_size_bits.into())
270}
271
272pub fn check_proof_of_time_wesolowski<T: BigNum, V: ClassGroup<BigNum = T>>(
273    challenge: &[u8],
274    proof_blob: &[u8],
275    iterations: u64,
276    int_size_bits: u16,
277) -> Result<(), ()>
278where
279    T: BigNumExt,
280{
281    let discriminant: T = super::create_discriminant::create_discriminant(challenge, int_size_bits);
282    let x = V::from_ab_discriminant(2.into(), 1.into(), discriminant.clone());
283    if (usize::MAX - 16) < int_size_bits.into() {
284        return Err(());
285    }
286    let int_size = (usize::from(int_size_bits) + 16) >> 4;
287    if int_size * 4 != proof_blob.len() {
288        return Err(());
289    }
290    let (result_bytes, proof_bytes) = proof_blob.split_at(2 * int_size);
291    let proof = ClassGroup::from_bytes(proof_bytes, discriminant.clone());
292    let y = ClassGroup::from_bytes(result_bytes, discriminant);
293
294    verify_proof(x, &y, proof, iterations, int_size_bits.into())
295}