Skip to main content

vdf_classgroup/gmp_classgroup/
ffi.rs

1// Copyright 2018 POA Networks Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! FFI bindings to GMP.  This module exists because the `rust-gmp` crate
16//! is too high-level.  High-performance bignum computation requires that
17//! bignums be modified in-place, so that their storage can be reused.
18//! Furthermore, the `rust-gmp` crate doesn’t support many operations that
19//! this library requires.
20#![allow(unsafe_code)]
21pub use super::super::gmp::mpz::Mpz;
22use super::super::gmp::mpz::{mp_bitcnt_t, mp_limb_t};
23use libc::{c_int, c_long, c_ulong, c_void, size_t};
24// pub use c_ulong;
25use std::mem::MaybeUninit;
26use std::usize;
27// We use the unsafe versions to avoid unecessary allocations.
28#[link(name = "gmp")]
29extern "C" {
30    fn __gmpz_gcdext(gcd: *mut Mpz, s: *mut Mpz, t: *mut Mpz, a: *const Mpz, b: *const Mpz);
31    fn __gmpz_gcd(rop: *mut Mpz, op1: *const Mpz, op2: *const Mpz);
32    fn __gmpz_fdiv_qr(q: *mut Mpz, r: *mut Mpz, b: *const Mpz, g: *const Mpz);
33    fn __gmpz_fdiv_q(q: *mut Mpz, n: *const Mpz, d: *const Mpz);
34    fn __gmpz_divexact(q: *mut Mpz, n: *const Mpz, d: *const Mpz);
35    fn __gmpz_tdiv_q(q: *mut Mpz, n: *const Mpz, d: *const Mpz);
36    fn __gmpz_mul(p: *mut Mpz, a: *const Mpz, b: *const Mpz);
37    fn __gmpz_mul_2exp(rop: *mut Mpz, op1: *const Mpz, op2: mp_bitcnt_t);
38    fn __gmpz_sub(rop: *mut Mpz, op1: *const Mpz, op2: *const Mpz);
39    fn __gmpz_import(
40        rop: *mut Mpz,
41        count: size_t,
42        order: c_int,
43        size: size_t,
44        endian: c_int,
45        nails: size_t,
46        op: *const c_void,
47    );
48    fn __gmpz_tdiv_r(r: *mut Mpz, n: *const Mpz, d: *const Mpz);
49    fn __gmpz_sizeinbase(op: &Mpz, base: c_int) -> size_t;
50    fn __gmpz_fdiv_q_ui(rop: *mut Mpz, op1: *const Mpz, op2: c_ulong) -> c_ulong;
51    fn __gmpz_add(rop: *mut Mpz, op1: *const Mpz, op2: *const Mpz);
52    fn __gmpz_add_ui(rop: *mut Mpz, op1: *const Mpz, op2: c_ulong);
53    fn __gmpz_set_ui(rop: &mut Mpz, op: c_ulong);
54    fn __gmpz_set_si(rop: &mut Mpz, op: c_long);
55    fn __gmpz_cdiv_ui(n: &Mpz, d: c_ulong) -> c_ulong;
56    fn __gmpz_fdiv_ui(n: &Mpz, d: c_ulong) -> c_ulong;
57    fn __gmpz_tdiv_ui(n: &Mpz, d: c_ulong) -> c_ulong;
58    fn __gmpz_export(
59        rop: *mut c_void,
60        countp: *mut size_t,
61        order: c_int,
62        size: size_t,
63        endian: c_int,
64        nails: size_t,
65        op: &Mpz,
66    ) -> *mut c_void;
67    fn __gmpz_powm(rop: *mut Mpz, base: *const Mpz, exp: *const Mpz, modulus: *const Mpz);
68}
69
// MEGA HACK: rust-gmp doesn’t expose the fields of this struct, so we must define
// it ourselves and cast.
//
// Should be stable though, as only GMP can change it, and doing so would break
// binary compatibility.
//
// The field order and types are intended to mirror GMP's `__mpz_struct`; code in
// this module reinterprets `&Mpz` as `&MpzStruct`, so the layout must not change.
#[repr(C)]
struct MpzStruct {
    // Presumably the number of limbs allocated at `mp_d` (per GMP's internals
    // documentation); never read in this module.
    mp_alloc: c_int,
    // Carries the sign of the value: `mpz_is_negative` treats `mp_size < 0`
    // as "negative", and `mpz_neg` negates the integer by negating this field.
    mp_size: c_int,
    // Presumably the pointer to the limb storage (per GMP's internals
    // documentation); never dereferenced in this module.
    mp_d: *mut mp_limb_t,
}
80
81macro_rules! impl_div_ui {
82    ($t:ident, $i:ident, $f:expr) => {
83        pub fn $i(n: &Mpz, d: $t) -> $t {
84            use std::$t;
85            let res = unsafe { $f(n, c_ulong::from(d)) };
86            assert!(res <= $t::MAX.into());
87            res as $t
88        }
89    };
90}
91
92impl_div_ui!(u16, mpz_crem_u16, __gmpz_cdiv_ui);
93impl_div_ui!(u32, mpz_frem_u32, __gmpz_fdiv_ui);
94
95/// Returns `true` if `z` is negative and not zero.  Otherwise,
96/// returns `false`.
97#[inline]
98pub fn mpz_is_negative(z: &Mpz) -> bool {
99    unsafe { (*(z as *const _ as *const MpzStruct)).mp_size < 0 }
100}
101
102#[inline]
103pub fn mpz_powm(rop: &mut Mpz, base: &Mpz, exponent: &Mpz, modulus: &Mpz) {
104    unsafe { __gmpz_powm(rop, base, exponent, modulus) }
105}
106
107#[inline]
108pub fn mpz_tdiv_r(r: &mut Mpz, n: &Mpz, d: &Mpz) {
109    unsafe { __gmpz_tdiv_r(r, n, d) }
110}
111
112/// Sets `g` to the GCD of `a` and `b`.
113#[inline]
114pub fn mpz_gcdext(gcd: &mut Mpz, s: &mut Mpz, t: &mut Mpz, a: &Mpz, b: &Mpz) {
115    unsafe { __gmpz_gcdext(gcd, s, t, a, b) }
116}
117
118/// Doubles `rop` in-place
119#[inline]
120pub fn mpz_double(rop: &mut Mpz) {
121    if true {
122        // slightly faster
123        unsafe { __gmpz_mul_2exp(rop, rop, 1) }
124    } else {
125        unsafe { __gmpz_add(rop, rop, rop) }
126    }
127}
128
129#[inline]
130pub fn mpz_fdiv_qr(q: &mut Mpz, r: &mut Mpz, b: &Mpz, g: &Mpz) {
131    unsafe { __gmpz_fdiv_qr(q, r, b, g) }
132}
133
134#[inline]
135pub fn mpz_fdiv_q_ui_self(rop: &mut Mpz, op: c_ulong) -> c_ulong {
136    unsafe { __gmpz_fdiv_q_ui(rop, rop, op) }
137}
138
139/// Unmarshals a buffer to an `Mpz`.  `buf` is interpreted as a 2’s complement,
140/// big-endian integer.  If the buffer is empty, zero is returned.
141pub fn import_obj(buf: &[u8]) -> Mpz {
142    fn raw_import(buf: &[u8]) -> Mpz {
143        let mut obj = Mpz::new();
144
145        unsafe { __gmpz_import(&mut obj, buf.len(), 1, 1, 1, 0, buf.as_ptr() as *const _) }
146        obj
147    }
148    let is_negative = match buf.first() {
149        None => return Mpz::zero(),
150        Some(x) => x & 0x80 != 0,
151    };
152    if !is_negative {
153        raw_import(buf)
154    } else {
155        let mut new_buf: Vec<_> = buf.iter().cloned().skip_while(|&x| x == 0xFF).collect();
156        if new_buf.is_empty() {
157            (-1).into()
158        } else {
159            for i in &mut new_buf {
160                *i ^= 0xFF
161            }
162            !raw_import(&new_buf)
163        }
164    }
165}
166
167pub fn three_gcd(rop: &mut Mpz, a: &Mpz, b: &Mpz, c: &Mpz) {
168    unsafe {
169        __gmpz_gcd(rop, a, b);
170        __gmpz_gcd(rop, rop, c)
171    }
172}
173
174#[inline]
175pub fn size_in_bits(obj: &Mpz) -> usize {
176    unsafe { __gmpz_sizeinbase(obj, 2) }
177}
178
179#[inline]
180pub fn mpz_add(rop: &mut Mpz, op1: &Mpz, op2: &Mpz) {
181    unsafe { __gmpz_add(rop, op1, op2) }
182}
183
184#[inline]
185pub fn mpz_mul(rop: &mut Mpz, op1: &Mpz, op2: &Mpz) {
186    unsafe { __gmpz_mul(rop, op1, op2) }
187}
188
189#[inline]
190pub fn mpz_divexact(q: &mut Mpz, n: &Mpz, d: &Mpz) {
191    unsafe { __gmpz_divexact(q, n, d) }
192}
193
194#[inline]
195pub fn mpz_mul_2exp(rop: &mut Mpz, op1: &Mpz, op2: mp_bitcnt_t) {
196    unsafe { __gmpz_mul_2exp(rop as *mut _ as *mut Mpz, op1, op2) }
197}
198
199/// Divide `n` by `d`.  Round towards -∞ and place the result in `q`.
200#[inline]
201pub fn mpz_fdiv_q(q: &mut Mpz, n: &Mpz, d: &Mpz) {
202    if mpz_is_negative(n) == mpz_is_negative(d) {
203        unsafe { __gmpz_tdiv_q(q, n, d) }
204    } else {
205        unsafe { __gmpz_fdiv_q(q, n, d) }
206    }
207}
208
/// Negates `rop` in place by flipping the sign of its `mp_size` field.
///
/// Currently compiled out: `cfg(any())` is never satisfied.
#[inline]
#[cfg(any())]
pub fn mpz_neg(rop: &mut Mpz) {
    // The cast below is only meaningful if both types have the same size.
    assert!(std::mem::size_of::<Mpz>() == std::mem::size_of::<MpzStruct>());
    let raw = rop as *mut Mpz as *mut MpzStruct;
    // SAFETY: relies on `Mpz` having the layout described by `MpzStruct`;
    // `raw` comes from a live exclusive reference.
    unsafe {
        let current = (*raw).mp_size;
        (*raw).mp_size = -current;
    }
}
220
221/// Subtracts `op2` from `op1` and stores the result in `rop`.
222#[inline]
223pub fn mpz_sub(rop: &mut Mpz, op1: &Mpz, op2: &Mpz) {
224    unsafe { __gmpz_sub(rop as *mut _ as *mut Mpz, op1, op2) }
225}
226
227/// Exports `obj` to `v` as an array of 2’s complement, big-endian
228/// bytes.  If `v` is too small to hold the result, returns `Err(s)`,
229/// where `s` is the size needed to hold the exported version of `obj`.
230pub fn export_obj(obj: &Mpz, v: &mut [u8]) -> Result<(), usize> {
231    // Requires: offset < v.len() and v[offset..] be able to hold all of `obj`
232    unsafe fn raw_export(v: &mut [u8], offset: usize, obj: &Mpz) -> usize {
233        // SAFE as `offset` will always be in-bounds, since byte_len always <=
234        // byte_len_needed and we check that v.len() >= byte_len_needed.
235        let ptr = v.as_mut_ptr().add(offset) as *mut c_void;
236
237        // Necessary ― this byte may not be fully overwritten
238        *(ptr as *mut u8) = 0;
239
240        // SAFE as __gmpz_export will *always* initialize this.
241        let mut s = MaybeUninit::<usize>::uninit();
242        let ptr2 = __gmpz_export(ptr, s.as_mut_ptr(), 1, 1, 1, 0, obj);
243        assert_eq!(ptr, ptr2);
244        let s = unsafe { s.assume_init() };
245        if 0 == s {
246            1
247        } else {
248            s
249        }
250    }
251
252    let size = size_in_bits(obj);
253    assert!(size > 0);
254
255    // Check to avoid integer overflow in later operations.
256    if size > usize::MAX - 8 || v.len() > usize::MAX >> 3 {
257        return Err(usize::MAX);
258    }
259
260    // One additional bit is needed for the sign bit.
261    let byte_len_needed = (size + 8) >> 3;
262    if v.len() < byte_len_needed {
263        return if v.is_empty() && obj.is_zero() {
264            Ok(())
265        } else {
266            Err(byte_len_needed)
267        };
268    }
269    let is_negative = mpz_is_negative(obj);
270
271    if is_negative {
272        // MEGA HACK: GMP does not have a function to perform 2's complement
273        let obj = !obj;
274        debug_assert!(
275            !mpz_is_negative(&obj),
276            "bitwise negation of a negative number produced a negative number"
277        );
278        let new_byte_size = (size_in_bits(&obj) + 7) >> 3;
279        let offset = v.len() - new_byte_size;
280
281        for i in &mut v[..offset] {
282            *i = 0xFF
283        }
284        unsafe {
285            assert_eq!(raw_export(v, offset, &obj), new_byte_size);
286        }
287
288        // We had to do a one’s complement to get the data in a decent format,
289        // so now we need to flip all of the bits back.  LLVM should be able to
290        // vectorize this loop easily.
291        for i in &mut v[offset..] {
292            *i ^= 0xFF
293        }
294    } else {
295        // ...but GMP will not include that in the number of bytes it writes
296        // (except for negative numbers)
297        let byte_len = (size + 7) >> 3;
298        assert!(byte_len > 0);
299
300        let offset = v.len() - byte_len;
301
302        // Zero out any leading bytes
303        for i in &mut v[..offset] {
304            *i = 0
305        }
306        unsafe {
307            assert_eq!(raw_export(v, offset, &obj), byte_len);
308        }
309    }
310
311    Ok(())
312}
313
#[cfg(test)]
mod test {
    use super::*;

    /// `size_in_bits` measures the magnitude, and `!`/`setbit` behave as
    /// two's-complement operations.
    #[test]
    fn check_expected_bit_width() {
        let minus_two: Mpz = (-2).into();
        assert_eq!(size_in_bits(&minus_two), 2);

        let mut s = !minus_two;
        assert_eq!(s, 1.into());

        s.setbit(2);
        assert_eq!(s, 5.into());
    }

    /// A negative value is sign-extended to fill the buffer, and zero may be
    /// exported into an empty buffer.
    #[test]
    fn check_export() {
        let positive: Mpz = 0x100.into();
        let negative = !positive;

        let mut buf = [0u8; 3];
        export_obj(&negative, &mut buf).expect("buffer should be large enough");
        assert_eq!(buf, [0xFF, 0xFE, 0xFF]);

        export_obj(&Mpz::zero(), &mut []).unwrap();
    }

    /// The ceiling remainder is reported as a magnitude for either sign of
    /// the dividend.
    #[test]
    fn check_rem() {
        let negative: Mpz = (-100i64).into();
        let positive: Mpz = (100i64).into();
        assert_eq!(mpz_crem_u16(&negative, 3), 1);
        assert_eq!(mpz_crem_u16(&positive, 3), 2);
    }
}