Skip to main content

vdf_classgroup/gmp_classgroup/mod.rs

1// Copyright 2018 Chia Network Inc and POA Networks Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14#![deny(unsafe_code)]
15use super::gmp::mpz::Mpz;
16use super::gmp::mpz::ProbabPrimeResult::NotPrime;
17use super::ClassGroup;
18use num_traits::{One, Zero};
19use std::{
20    borrow::Borrow,
21    cell::RefCell,
22    mem::swap,
23    ops::{Mul, MulAssign},
24};
25mod congruence;
26pub(super) mod ffi;
27
28#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Clone)]
29pub struct GmpClassGroup {
30    a: Mpz,
31    b: Mpz,
32    c: Mpz,
33    discriminant: Mpz,
34}
35
36#[derive(PartialEq, PartialOrd, Eq, Ord, Clone, Hash, Debug)]
37pub struct Ctx {
38    negative_a: Mpz,
39    r: Mpz,
40    denom: Mpz,
41    old_a: Mpz,
42    old_b: Mpz,
43    ra: Mpz,
44    s: Mpz,
45    x: Mpz,
46    congruence_context: congruence::CongruenceContext,
47    h: Mpz,
48    w: Mpz,
49    m: Mpz,
50    u: Mpz,
51    a: Mpz,
52    l: Mpz,
53    j: Mpz,
54    b: Mpz,
55    k: Mpz,
56    t: Mpz,
57    mu: Mpz,
58    v: Mpz,
59    sigma: Mpz,
60    lambda: Mpz,
61}
62
63thread_local! {
64    static CTX: RefCell<Ctx> = Default::default();
65}
66
67impl GmpClassGroup {
68    pub fn into_raw(self) -> (Mpz, Mpz) {
69        (self.a, self.b)
70    }
71
72    fn inner_multiply(&mut self, rhs: &Self, ctx: &mut Ctx) {
73        self.assert_valid();
74        rhs.assert_valid();
75
76        // g = (b1 + b2) / 2
77        ffi::mpz_add(&mut ctx.congruence_context.g, &self.b, &rhs.b);
78        ffi::mpz_fdiv_q_ui_self(&mut ctx.congruence_context.g, 2);
79
80        // h = (b2 - b1) / 2
81        ffi::mpz_sub(&mut ctx.h, &rhs.b, &self.b);
82        ffi::mpz_fdiv_q_ui_self(&mut ctx.h, 2);
83
84        debug_assert!(&ctx.h + &ctx.congruence_context.g == rhs.b);
85        debug_assert!(&ctx.congruence_context.g - &ctx.h == self.b);
86
87        // w = gcd(a1, a2, g)
88        ffi::three_gcd(&mut ctx.w, &self.a, &rhs.a, &ctx.congruence_context.g);
89
90        // j = w
91        ctx.j.set(&ctx.w);
92
93        // s = a1/w
94        ffi::mpz_fdiv_q(&mut ctx.s, &self.a, &ctx.w);
95
96        // t = a2/w
97        ffi::mpz_fdiv_q(&mut ctx.t, &rhs.a, &ctx.w);
98
99        // u = g/w
100        ffi::mpz_fdiv_q(&mut ctx.u, &ctx.congruence_context.g, &ctx.w);
101
102        // a = t*u
103        ffi::mpz_mul(&mut ctx.a, &ctx.t, &ctx.u);
104
105        // b = h*u - s*c1
106        ffi::mpz_mul(&mut ctx.b, &ctx.h, &ctx.u);
107        ffi::mpz_mul(&mut ctx.m, &ctx.s, &self.c);
108        ctx.b += &ctx.m;
109
110        // m = s*t
111        ffi::mpz_mul(&mut ctx.m, &ctx.s, &ctx.t);
112        ctx.congruence_context.solve_linear_congruence(
113            &mut ctx.mu,
114            Some(&mut ctx.v),
115            &ctx.a,
116            &ctx.b,
117            &ctx.m,
118        );
119
120        // a = t*v
121        ffi::mpz_mul(&mut ctx.a, &ctx.t, &ctx.v);
122
123        // b = h - t * mu
124        ffi::mpz_mul(&mut ctx.m, &ctx.t, &ctx.mu);
125        ffi::mpz_sub(&mut ctx.b, &ctx.h, &ctx.m);
126
127        // m = s
128        ctx.m.set(&ctx.s);
129
130        ctx.congruence_context.solve_linear_congruence(
131            &mut ctx.lambda,
132            Some(&mut ctx.sigma),
133            &ctx.a,
134            &ctx.b,
135            &ctx.m,
136        );
137
138        // k = mu + v*lambda
139        ffi::mpz_mul(&mut ctx.a, &ctx.v, &ctx.lambda);
140        ffi::mpz_add(&mut ctx.k, &ctx.mu, &ctx.a);
141
142        // l = (k*t - h)/s
143        ffi::mpz_mul(&mut ctx.l, &ctx.k, &ctx.t);
144        ffi::mpz_sub(&mut ctx.v, &ctx.l, &ctx.h);
145        ffi::mpz_fdiv_q(&mut ctx.l, &ctx.v, &ctx.s);
146
147        // m = (t*u*k - h*u - c*s) / s*t
148        ffi::mpz_mul(&mut ctx.m, &ctx.t, &ctx.u);
149        ctx.m *= &ctx.k;
150        ffi::mpz_mul(&mut ctx.a, &ctx.h, &ctx.u);
151        ctx.m -= &ctx.a;
152        ffi::mpz_mul(&mut ctx.a, &self.c, &ctx.s);
153        ctx.m -= &ctx.a;
154        ffi::mpz_mul(&mut ctx.a, &ctx.s, &ctx.t);
155        ffi::mpz_fdiv_q(&mut ctx.lambda, &ctx.m, &ctx.a);
156
157        // A = s*t - r*u
158        ffi::mpz_mul(&mut self.a, &ctx.s, &ctx.t);
159
160        // B = ju + mr - (kt + ls)
161        ffi::mpz_mul(&mut self.b, &ctx.j, &ctx.u);
162        ffi::mpz_mul(&mut ctx.a, &ctx.k, &ctx.t);
163        self.b -= &ctx.a;
164        ffi::mpz_mul(&mut ctx.a, &ctx.l, &ctx.s);
165        self.b -= &ctx.a;
166
167        // C = kl - jm
168        ffi::mpz_mul(&mut self.c, &ctx.k, &ctx.l);
169        ffi::mpz_mul(&mut ctx.a, &ctx.j, &ctx.lambda);
170        self.c -= &ctx.a;
171
172        self.inner_reduce(ctx);
173    }
174
175    #[cfg_attr(not(debug_assertions), inline(always))]
176    fn new(a: Mpz, b: Mpz, c: Mpz, discriminant: Mpz) -> Self {
177        let s = GmpClassGroup {
178            a,
179            b,
180            c,
181            discriminant,
182        };
183        s.assert_valid();
184        s
185    }
186
187    #[cfg_attr(not(debug_assertions), inline(always))]
188    fn assert_valid(&self) {
189        if cfg!(debug_assertions) {
190            let four: Mpz = 4u64.into();
191            let four_ac: Mpz = four * &self.a * &self.c;
192            assert!(&self.discriminant + four_ac == &self.b * &self.b);
193        }
194    }
195
196    fn inner_normalize(&mut self, ctx: &mut Ctx) {
197        self.assert_valid();
198        ctx.negative_a = -&self.a;
199        if self.b > ctx.negative_a && self.b <= self.a {
200            return;
201        }
202        ffi::mpz_sub(&mut ctx.r, &self.a, &self.b);
203        ffi::mpz_mul_2exp(&mut ctx.denom, &self.a, 1);
204        ffi::mpz_fdiv_q(&mut ctx.negative_a, &ctx.r, &ctx.denom);
205        swap(&mut ctx.negative_a, &mut ctx.r);
206        swap(&mut ctx.old_b, &mut self.b);
207        ffi::mpz_mul(&mut ctx.ra, &ctx.r, &self.a);
208        ffi::mpz_mul_2exp(&mut ctx.negative_a, &ctx.ra, 1);
209        ffi::mpz_add(&mut self.b, &ctx.old_b, &ctx.negative_a);
210
211        ffi::mpz_mul(&mut ctx.negative_a, &ctx.ra, &ctx.r);
212        ffi::mpz_add(&mut ctx.old_a, &self.c, &ctx.negative_a);
213
214        ffi::mpz_mul(&mut ctx.ra, &ctx.r, &ctx.old_b);
215        ffi::mpz_add(&mut self.c, &ctx.old_a, &ctx.ra);
216
217        self.assert_valid();
218    }
219
220    fn inner_reduce(&mut self, ctx: &mut Ctx) {
221        self.inner_normalize(ctx);
222
223        while if ffi::mpz_is_negative(&self.b) {
224            self.a >= self.c
225        } else {
226            self.a > self.c
227        } {
228            debug_assert!(!self.c.is_zero());
229            ffi::mpz_add(&mut ctx.s, &self.c, &self.b);
230            ffi::mpz_add(&mut ctx.x, &self.c, &self.c);
231            swap(&mut self.b, &mut ctx.old_b);
232            ffi::mpz_fdiv_q(&mut self.b, &ctx.s, &ctx.x);
233            swap(&mut self.b, &mut ctx.s);
234            swap(&mut self.a, &mut self.c);
235
236            // x = 2sc
237            ffi::mpz_mul(&mut self.b, &ctx.s, &self.a);
238            ffi::mpz_mul_2exp(&mut ctx.x, &self.b, 1);
239
240            // b = x - old_b
241            ffi::mpz_sub(&mut self.b, &ctx.x, &ctx.old_b);
242
243            // x = b*s
244            ffi::mpz_mul(&mut ctx.x, &ctx.old_b, &ctx.s);
245
246            // s = c*s^2
247            ffi::mpz_mul(&mut ctx.old_b, &ctx.s, &ctx.s);
248            ffi::mpz_mul(&mut ctx.s, &self.a, &ctx.old_b);
249
250            // c = s - x
251            ffi::mpz_sub(&mut ctx.old_a, &ctx.s, &ctx.x);
252
253            // c += a
254            self.c += &ctx.old_a;
255        }
256        self.inner_normalize(ctx);
257    }
258
259    fn inner_square_impl(&mut self, ctx: &mut Ctx) {
260        self.assert_valid();
261        ctx.congruence_context.solve_linear_congruence(
262            &mut ctx.mu,
263            None,
264            &self.b,
265            &self.c,
266            &self.a,
267        );
268        ffi::mpz_mul(&mut ctx.m, &self.b, &ctx.mu);
269        ctx.m -= &self.c;
270        ctx.m = ctx.m.div_floor(&self.a);
271
272        // New a
273        ctx.old_a.set(&self.a);
274        ffi::mpz_mul(&mut self.a, &ctx.old_a, &ctx.old_a);
275
276        // New b
277        ffi::mpz_mul(&mut ctx.a, &ctx.mu, &ctx.old_a);
278        ffi::mpz_double(&mut ctx.a);
279        self.b -= &ctx.a;
280
281        // New c
282        ffi::mpz_mul(&mut self.c, &ctx.mu, &ctx.mu);
283        self.c -= &ctx.m;
284        self.inner_reduce(ctx);
285    }
286
287    #[cfg_attr(not(debug_assertions), inline(always))]
288    fn inner_square(&mut self, ctx: &mut Ctx) {
289        if cfg!(debug_assertions) {
290            let mut q = self.clone();
291            q.inner_multiply(self, ctx);
292            self.inner_square_impl(ctx);
293            assert_eq!(*self, q);
294        } else {
295            self.inner_square_impl(ctx);
296        }
297    }
298
299    /// Call `cb` with a mutable reference to the context of type `Ctx`.
300    ///
301    /// The reference cannot escape the closure and cannot be sent across
302    /// threads.
303    ///
304    /// # Panics
305    ///
306    /// Panics if called recursively.  This library guarantees that it will
307    /// never call this function from any function that takes a parameter of
308    /// type `&mut Ctx`.
309    pub fn with_context<T, U>(cb: T) -> U
310    where
311        T: FnOnce(&mut Ctx) -> U,
312    {
313        let mut opt = None;
314        CTX.with(|x| opt = Some(cb(&mut x.borrow_mut())));
315        opt.unwrap()
316    }
317}
318
319impl Default for GmpClassGroup {
320    fn default() -> Self {
321        GmpClassGroup {
322            a: Mpz::new(),
323            b: Mpz::new(),
324            c: Mpz::new(),
325            discriminant: Mpz::new(),
326        }
327    }
328}
329
330impl<B: Borrow<GmpClassGroup>> MulAssign<B> for GmpClassGroup {
331    #[cfg_attr(not(debug_assertions), inline(always))]
332    fn mul_assign(&mut self, rhs: B) {
333        let rhs = rhs.borrow();
334        debug_assert!(self.discriminant == rhs.discriminant);
335        GmpClassGroup::with_context(|ctx| self.inner_multiply(rhs, ctx));
336    }
337}
338
339impl super::BigNum for Mpz {
340    fn probab_prime(&self, iterations: u32) -> bool {
341        self.probab_prime(iterations.max(256) as _) != NotPrime
342    }
343
344    fn setbit(&mut self, bit_index: usize) {
345        self.setbit(bit_index)
346    }
347
348    fn mod_powm(&mut self, base: &Self, exponent: &Self, modulus: &Self) {
349        ffi::mpz_powm(self, base, exponent, modulus)
350    }
351}
352
353impl super::BigNumExt for Mpz {
354    fn frem_u32(&self, modulus: u32) -> u32 {
355        ffi::mpz_frem_u32(self, modulus)
356    }
357    fn crem_u16(&mut self, modulus: u16) -> u16 {
358        ffi::mpz_crem_u16(self, modulus)
359    }
360}
361
362impl<B: Borrow<Self>> Mul<B> for GmpClassGroup {
363    type Output = Self;
364    #[inline]
365    fn mul(mut self, rhs: B) -> Self {
366        self *= rhs.borrow();
367        self
368    }
369}
370
371impl<'a, B: Borrow<GmpClassGroup>> Mul<B> for &'a GmpClassGroup {
372    type Output = GmpClassGroup;
373
374    #[inline(always)]
375    fn mul(self, rhs: B) -> Self::Output {
376        let mut s = Clone::clone(self);
377        s *= rhs;
378        s
379    }
380}
381
382impl ClassGroup for GmpClassGroup {
383    type BigNum = Mpz;
384
385    /// Normalize `self`.
386    ///
387    /// # Panics
388    ///
389    /// Panics if called within a call to `Self::with_context`.
390    fn normalize(&mut self) {
391        Self::with_context(|x| self.inner_normalize(x))
392    }
393
394    #[cfg_attr(not(debug_assertions), inline(always))]
395    fn inverse(&mut self) {
396        self.assert_valid();
397        self.b = -self.b.clone();
398    }
399
400    fn serialize(&self, buf: &mut [u8]) -> Result<(), usize> {
401        self.assert_valid();
402        if buf.len() & 1 == 1 {
403            // odd lengths do not make sense
404            Err(0)
405        } else {
406            let len = buf.len() >> 1;
407            ffi::export_obj(&self.a, &mut buf[..len])?;
408            ffi::export_obj(&self.b, &mut buf[len..])
409        }
410    }
411
412    fn from_bytes(bytearray: &[u8], discriminant: Self::BigNum) -> Self {
413        let len = (ffi::size_in_bits(&discriminant) + 16) >> 4;
414        let a = ffi::import_obj(&bytearray[..len]);
415        let b = ffi::import_obj(&bytearray[len..]);
416        Self::from_ab_discriminant(a, b, discriminant)
417    }
418
419    fn from_ab_discriminant(a: Self::BigNum, b: Self::BigNum, discriminant: Self::BigNum) -> Self {
420        let mut four_a: Self::BigNum = 4u64.into();
421        four_a *= &a;
422        let c = (&b * &b - &discriminant) / four_a;
423        Self {
424            a,
425            b,
426            c,
427            discriminant,
428        }
429    }
430
431    /// Returns the discriminant of `self`.
432    #[inline(always)]
433    fn discriminant(&self) -> &Self::BigNum {
434        &self.discriminant
435    }
436
437    fn size_in_bits(num: &Self::BigNum) -> usize {
438        ffi::size_in_bits(num)
439    }
440
441    /// Reduce `self`.
442    ///
443    /// # Panics
444    ///
445    /// Panics if called within a call to `Self::with_context`.
446    fn reduce(&mut self) {
447        Self::with_context(|x| self.inner_reduce(x))
448    }
449
450    fn deserialize(buf: &[u8], discriminant: Self::BigNum) -> Self {
451        let len = buf.len();
452        debug_assert!(len != 0, "Cannot deserialize an empty buffer!");
453        debug_assert!(len & 1 == 0, "Buffer must be of even length");
454        let half_len = len >> 1;
455        Self::from_ab_discriminant(
456            ffi::import_obj(&buf[..half_len]),
457            ffi::import_obj(&buf[half_len..]),
458            discriminant,
459        )
460    }
461
462    /// Square `self`.ClassGroupPartial
463    ///
464    /// # Panics
465    ///
466    /// Panics if called within the scope of a call to `with_context`.
467    fn square(&mut self) {
468        Self::with_context(|ctx| self.inner_square(ctx))
469    }
470
471    fn unsigned_deserialize_bignum(buf: &[u8]) -> Self::BigNum {
472        buf.into()
473    }
474
475    /// Square `self` `iterations` times.
476    ///
477    /// # Panics
478    ///
479    /// Panics if called within the scope of a call to `with_context`.
480    fn repeated_square(&mut self, iterations: u64) {
481        Self::with_context(|ctx| {
482            for _ in 0..iterations {
483                self.inner_square(ctx)
484            }
485        })
486    }
487
488    fn generator_for_discriminant(discriminant: Self::BigNum) -> Self {
489        let one: Mpz = One::one();
490        let x: Mpz = &one - &discriminant;
491        let mut form = Self::new(2.into(), one, x.div_floor(&8.into()), discriminant);
492        form.reduce();
493        form
494    }
495
496    fn pow(&mut self, mut exponent: Mpz) {
497        self.assert_valid();
498        debug_assert!(exponent >= Mpz::zero());
499        let mut state = self.identity();
500        loop {
501            let is_odd = exponent.tstbit(0);
502            exponent >>= 1;
503            if is_odd {
504                state *= &*self
505            }
506            if exponent.is_zero() {
507                *self = state;
508                break;
509            }
510            self.square();
511        }
512    }
513}
514
515impl Default for Ctx {
516    fn default() -> Self {
517        Self {
518            negative_a: Mpz::new(),
519            r: Mpz::new(),
520            denom: Mpz::new(),
521            old_a: Mpz::new(),
522            old_b: Mpz::new(),
523            ra: Mpz::new(),
524            s: Mpz::new(),
525            x: Mpz::new(),
526            congruence_context: Default::default(),
527            w: Mpz::new(),
528            m: Mpz::new(),
529            u: Mpz::new(),
530            l: Mpz::new(),
531            j: Mpz::new(),
532            t: Mpz::new(),
533            a: Mpz::new(),
534            b: Mpz::new(),
535            k: Mpz::new(),
536            h: Mpz::new(),
537            mu: Mpz::new(),
538            v: Mpz::new(),
539            sigma: Mpz::new(),
540            lambda: Mpz::new(),
541        }
542    }
543}
544
545pub fn do_compute(discriminant: Mpz, iterations: u64) -> GmpClassGroup {
546    debug_assert!(discriminant < Zero::zero());
547    debug_assert!(discriminant.probab_prime(50) != NotPrime);
548    let mut f = GmpClassGroup::generator_for_discriminant(discriminant);
549    f.repeated_square(iterations);
550    f
551}
552
#[cfg(test)]
mod test {
    #![allow(unused_imports)]
    use super::*;

    /// Checks `normalize` on two forms of discriminant `-0xdead_beef`, each
    /// needing a single shift with `r = floor((a - b) / 2a) = 1`.
    #[test]
    fn normalize() {
        // (16, -23, 58373892): b -> -23 + 2*16 = 9, c -> 16 - 23 + 58373892.
        let mut form = GmpClassGroup::new(
            16.into(),
            (-23).into(),
            5837_3892.into(),
            (-0xdead_beefi64).into(),
        );
        let expected = GmpClassGroup {
            b: 9.into(),
            c: 5837_3885.into(),
            ..form.clone()
        };
        form.normalize();
        assert_eq!(form, expected);

        // (65536, -76951, 36840): b -> -76951 + 2*65536 = 54121,
        // c -> 65536 - 76951 + 36840 = 25425.
        form = GmpClassGroup {
            a: (1 << 16).into(),
            b: (-76951).into(),
            c: 36840.into(),
            ..form
        };
        let expected = GmpClassGroup {
            b: 54121.into(),
            c: 25425.into(),
            ..form.clone()
        };
        form.normalize();
        assert_eq!(form, expected);
    }
}