simplerandom/
maths.rs

1use num_traits::{
2    ConstOne, ConstZero, NumCast, PrimInt, Signed, Unsigned, WrappingAdd, WrappingMul, WrappingNeg,
3    WrappingSub,
4};
5use std::ops::{AddAssign, BitAnd, MulAssign};
6
/// Size of type `T` in bits: its size in bytes times the number of bits in a byte.
pub const fn size_of_bits<T>() -> usize {
    const BITS_PER_BYTE: usize = u8::BITS as usize;
    BITS_PER_BYTE * std::mem::size_of::<T>()
}
10
11pub fn bit_width_mask<T>(bit_width: usize) -> T
12where
13    T: PrimInt + ConstOne,
14{
15    if bit_width < size_of_bits::<T>() {
16        (T::ONE << bit_width) - T::ONE
17    } else {
18        !(T::ONE)
19    }
20}
21
22/// Unsigned integer types
23///
24pub trait UIntTypes:
25    PrimInt + Unsigned + ConstOne + ConstZero + WrappingAdd + WrappingSub + Copy
26{
27    /// Multiply unsigned `a` and `b`, modulo `m`
28    ///
29    /// It can be specialised for each integer type. See the comments on
30    /// the generic mul_mod() implementation below.
31    ///
32    fn mul_mod(a: Self, b: Self, m: Self) -> Self;
33}
34
35/// Multiply unsigned `a` and `b`, modulo `m`
36///
37/// This is a generic implementation that should work for any unsigned
38/// primitive integer type.
39///
40/// Note that for most integer types for which a larger integer type is
41/// available for intermediate calculations, it may be faster to implement it
42/// by doing a simple multiplication in the larger integer type, calculating
43/// the modulo, then casting back to the result type. That can be done for
44/// u8 through u64, but not u128 and perhaps not usize. See
45/// UIntTypes::mul_mod() which can be specialised for each type.
46///
47/// # Arguments
48///
49/// Multiplicands `a` and `b` can be any unsigned primitive integer.
50/// Modulus `m` can be any unsigned primitive integer.
51///
52/// # Return
53///
54/// The result is the multiplication of `a` and `b`, modulo `m`.
55///
56///     use simplerandom::maths::mul_mod;
57///     let result = mul_mod(123456789_u32, 3111222333, 0x9068FFFF);
58///     assert_eq!(1473911797_u32, result);
59///     let result = mul_mod(12345678901234567890_u64, 10222333444555666777, 0x29A65EACFFFFFFFF);
60///     assert_eq!(1000040008665797219_u64, result);
61///
62pub fn mul_mod_generic<T>(a: T, b: T, m: T) -> T
63where
64    T: UIntTypes,
65{
66    let mut a_work: T = a;
67    let mut b_work: T = b;
68    let mut result: T = T::ZERO;
69
70    if b_work >= m {
71        if m > T::max_value() / (T::ONE + T::ONE) {
72            b_work = b_work.wrapping_sub(&m);
73        } else {
74            b_work = b_work % m;
75        }
76    }
77
78    while a_work != T::ZERO {
79        if a_work & T::ONE != T::ZERO {
80            if b_work >= m - result {
81                result = result.wrapping_sub(&m);
82            }
83            result = result.wrapping_add(&b_work);
84        }
85        a_work = a_work >> 1;
86
87        let mut temp_b = b_work;
88        if b_work >= m - temp_b {
89            temp_b = temp_b.wrapping_sub(&m);
90        }
91        b_work = b_work.wrapping_add(&temp_b);
92    }
93    result
94}
95
96impl UIntTypes for u8 {
97    /// Simple specialisation using the next larger integer type
98    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
99        ((a as u16) * (b as u16) % (m as u16)) as u8
100    }
101}
102impl UIntTypes for u16 {
103    /// Simple specialisation using the next larger integer type
104    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
105        ((a as u32) * (b as u32) % (m as u32)) as u16
106    }
107}
108impl UIntTypes for u32 {
109    /// Simple specialisation using the next larger integer type
110    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
111        ((a as u64) * (b as u64) % (m as u64)) as u32
112    }
113}
114impl UIntTypes for u64 {
115    /// Simple specialisation using the next larger integer type
116    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
117        ((a as u128) * (b as u128) % (m as u128)) as u64
118    }
119}
120impl UIntTypes for u128 {
121    /// Use the generic implementation
122    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
123        mul_mod_generic::<Self>(a, b, m)
124    }
125}
126impl UIntTypes for usize {
127    /// Use the generic implementation
128    fn mul_mod(a: Self, b: Self, m: Self) -> Self {
129        mul_mod_generic::<Self>(a, b, m)
130    }
131}
132
133pub fn mul_mod<T>(a: T, b: T, m: T) -> T
134where
135    T: UIntTypes,
136{
137    T::mul_mod(a, b, m)
138}
139
140/// Primitive integer types
141///
142/// Mappings to associated signed and unsigned types with the same bit width.
143///
144pub trait IntTypes:
145    PrimInt + NumCast + ConstZero + ConstOne + WrappingAdd + WrappingNeg + Copy
146{
147    type SignedType: PrimInt + Signed + ConstZero + ConstOne + Copy + NumCast;
148    type UnsignedType: PrimInt + Unsigned + ConstZero + ConstOne + Copy + NumCast + WrappingNeg;
149    type OtherSignType: PrimInt + ConstOne + ConstZero + Copy + NumCast;
150
151    /// `abs()` function which returns a corresponding unsigned type
152    ///
153    /// For unsigned input types, just return the same value.
154    /// For signed types, return the unsigned type of the same bit width.
155    ///
156    fn abs_as_unsigned(a: Self) -> Self::UnsignedType;
157}
158
159/// `abs()` function which returns a corresponding unsigned type
160///
161/// For unsigned input types, just return the same value.
162/// For signed types, return the unsigned type of the same bit width.
163///
164/// This is a generic implementation which should work for all primitive
165/// integers, both signed and unsigned.
166pub fn abs_as_unsigned_generic<T>(a: T) -> T::UnsignedType
167where
168    T: IntTypes,
169{
170    if a < T::ZERO {
171        // Negative input. Negate it.
172        let result: Option<T::UnsignedType> = NumCast::from(a.wrapping_neg());
173        if result.is_some() {
174            // The vast majority of values.
175            result.unwrap()
176        } else {
177            // The exceptional case: in two's complement form, the lowest
178            // negative number's negation doesn't fit into the signed type.
179            let result_minus_1: Option<T::UnsignedType> =
180                NumCast::from((a + T::ONE).wrapping_neg());
181            result_minus_1.unwrap_or(T::UnsignedType::ZERO) + T::UnsignedType::ONE
182        }
183    } else {
184        // Positive input. Return it as-is.
185        let result: Option<T::UnsignedType> = NumCast::from(a);
186        result.unwrap_or(T::UnsignedType::ZERO)
187    }
188}
189
190impl IntTypes for i8 {
191    type SignedType = i8;
192    type UnsignedType = u8;
193    type OtherSignType = u8;
194    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
195        abs_as_unsigned_generic::<Self>(a)
196    }
197}
198impl IntTypes for i16 {
199    type SignedType = i16;
200    type UnsignedType = u16;
201    type OtherSignType = u16;
202    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
203        abs_as_unsigned_generic::<Self>(a)
204    }
205}
206impl IntTypes for i32 {
207    type SignedType = i32;
208    type UnsignedType = u32;
209    type OtherSignType = u32;
210    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
211        abs_as_unsigned_generic::<Self>(a)
212    }
213}
214impl IntTypes for i64 {
215    type SignedType = i64;
216    type UnsignedType = u64;
217    type OtherSignType = u64;
218    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
219        abs_as_unsigned_generic::<Self>(a)
220    }
221}
222impl IntTypes for i128 {
223    type SignedType = i128;
224    type UnsignedType = u128;
225    type OtherSignType = u128;
226    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
227        abs_as_unsigned_generic::<Self>(a)
228    }
229}
230impl IntTypes for isize {
231    type SignedType = isize;
232    type UnsignedType = usize;
233    type OtherSignType = usize;
234    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
235        abs_as_unsigned_generic::<Self>(a)
236    }
237}
238impl IntTypes for u8 {
239    type SignedType = i8;
240    type UnsignedType = u8;
241    type OtherSignType = i8;
242    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
243        a
244    }
245}
246impl IntTypes for u16 {
247    type SignedType = i16;
248    type UnsignedType = u16;
249    type OtherSignType = i16;
250    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
251        a
252    }
253}
254impl IntTypes for u32 {
255    type SignedType = i32;
256    type UnsignedType = u32;
257    type OtherSignType = i32;
258    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
259        a
260    }
261}
262impl IntTypes for u64 {
263    type SignedType = i64;
264    type UnsignedType = u64;
265    type OtherSignType = i64;
266    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
267        a
268    }
269}
270impl IntTypes for u128 {
271    type SignedType = i128;
272    type UnsignedType = u128;
273    type OtherSignType = i128;
274    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
275        a
276    }
277}
278impl IntTypes for usize {
279    type SignedType = isize;
280    type UnsignedType = usize;
281    type OtherSignType = isize;
282    fn abs_as_unsigned(a: Self) -> Self::UnsignedType {
283        a
284    }
285}
286
287/// `abs()` function which returns a corresponding unsigned type
288///
289/// For unsigned input types, just return the same value.
290/// For signed types, return the unsigned type of the same bit width.
291///
292/// This is a specialised implementation which uses the [`IntTypes`] trait funtion, and is defined
293/// differently for the signed and unsigned integers.
294pub fn abs_as_unsigned<T>(a: T) -> T::UnsignedType
295where
296    T: IntTypes,
297{
298    IntTypes::abs_as_unsigned(a)
299}
300
301/// Calculate `a` modulo `m`
302///
303/// # Arguments
304///
305/// `a` can be any primitive integer, signed or unsigned.
306/// `m` can be any unsigned primitive integer.
307///
308/// # Return
309///
310/// The result is the same unsigned type as that of parameter `m`.
311/// The result is in the range [0..m] even when `a` is negative.
312///
313///     use simplerandom::maths::modulo;
314///     let result = modulo(12345_u32, 7_u32);
315///     assert_eq!(result, 4_u32);
316///     let result = modulo(-12345_i32, 7_u32);
317///     assert_eq!(result, 3_u32);
318///
319pub fn modulo<A, M>(a: A, m: M) -> M
320where
321    A: IntTypes,
322    M: PrimInt + Unsigned + ConstZero + Copy + NumCast,
323{
324    if a >= A::ZERO {
325        // Positive input.
326        let a_opt: Option<M> = NumCast::from(a);
327        if a_opt.is_some() {
328            // a fits into type M. Easy.
329            a_opt.unwrap() % m
330        } else {
331            // a doesn't fit into type M. m should fit into type A.
332            let m_opt: Option<A> = NumCast::from(m);
333            let result_a = a % m_opt.unwrap();
334            let result_m: Option<M> = NumCast::from(result_a);
335            result_m.unwrap()
336        }
337    } else {
338        // Negative input.
339        let a_abs = abs_as_unsigned(a);
340        let a_abs_opt: Option<M> = NumCast::from(a_abs);
341        if a_abs_opt.is_some() {
342            // a_abs fits into type M.
343            m - (a_abs_opt.unwrap() % m)
344        } else {
345            // a_abs doesn't fit into type M. m should fit into the corresponding unsigned type of A.
346            let m_opt: Option<A::UnsignedType> = NumCast::from(m);
347            let m_s = m_opt.unwrap();
348            let result_a = m_s - (a_abs % m_s);
349            let result_m: Option<M> = NumCast::from(result_a);
350            result_m.unwrap()
351        }
352    }
353}
354
355/// Exponentiation with wrapping
356///
357/// Calculation of `base` to the power of an unsigned integer `n`, with the
358/// natural modulo of the unsigned integer type T (ie, with wrapping).
359///
360///     use simplerandom::maths::wrapping_pow;
361///     let result = wrapping_pow(12345_u32, 1500000_u32);
362///     assert_eq!(result, 2764689665_u32);
363///
364pub fn wrapping_pow<T, N>(base: T, n: N) -> T
365where
366    T: PrimInt + Unsigned + WrappingMul + WrappingSub + ConstOne,
367    N: PrimInt + Unsigned + ConstOne + ConstZero + BitAnd,
368{
369    let mut result: T = T::ONE;
370    let mut temp_exp = base;
371    let mut n_work: N = n;
372
373    loop {
374        if n_work & N::ONE != N::ZERO {
375            result = result.wrapping_mul(&temp_exp);
376        }
377        n_work = n_work >> 1;
378        if n_work == N::ZERO {
379            break;
380        }
381        temp_exp = temp_exp.wrapping_mul(&temp_exp);
382    }
383    result
384}
385
386/// Modular exponentiation
387///
388/// Calculation of `base` to the power of an unsigned integer `n`,
389/// modulo a value `m`.
390///
391///     use simplerandom::maths::pow_mod;
392///     let result = pow_mod(12345_u32, 1500000_u32, 1211400191_u32);
393///     assert_eq!(result, 348133782_u32);
394///     let result = pow_mod(0xDC28D76FFD9338E9D868AF566191DE10_u128,
395///                           0x732E73C316878E244FDFDE4EE623CDCC_u128,
396///                           0xEC327D45470669CC56B547B6FE6888A2_u128);
397///     assert_eq!(result, 0x6AA4E49D8B90A5467A9655090EDD7940_u128);
398pub fn pow_mod<T, N>(base: T, n: N, m: T) -> T
399where
400    T: UIntTypes,
401    N: PrimInt + Unsigned + ConstOne + ConstZero + BitAnd,
402{
403    let mut result: T = T::ONE;
404    let mut temp_exp = base;
405    let mut n_work: N = n;
406
407    loop {
408        if n_work & N::ONE != N::ZERO {
409            result = mul_mod(result, temp_exp, m);
410        }
411        n_work = n_work >> 1;
412        if n_work == N::ZERO {
413            break;
414        }
415        temp_exp = mul_mod(temp_exp, temp_exp, m);
416    }
417    result
418}
419
420/// Calculate geometric series
421///
422/// That is, calculate the geometric series:
423///
424/// 1 + r + r^2 + r^3 + ... r^(n-1)
425///
426/// summed to `n` terms, with the natural modulo of the unsigned integer
427/// type T (ie, with wrapping).
428///
429/// It makes use of the fact that the series can pair up terms:
430///
431/// (1 + r) + (1 + r) r^2 + (1 + r) r^4 + ... + (1 + r) (r^2)^(n/2-1) + [ r^(n-1) if n is odd ]
432/// (1 + r) (1 + r^2 + r^4 + ... + (r^2)^(n/2-1)) + [ r^(n-1) if n is odd ]
433///
434/// Which can easily be calculated by recursion, with time order `O(log n)`, and
435/// stack depth `O(log n)`. However that stack depth isn't good, so a
436/// non-recursive implementation is preferable.
437/// This implementation is by a loop, not recursion, with time order
438/// `O(log n)` and stack depth `O(1)`.
439///
440///     use simplerandom::maths::wrapping_geom_series;
441///     let result = wrapping_geom_series(12345_u32, 1500000_u32);
442///     assert_eq!(result, 57634016_u32);
443///
444pub fn wrapping_geom_series<T, N>(r: T, n: N) -> T
445where
446    T: PrimInt
447        + Unsigned
448        + ConstOne
449        + ConstZero
450        + WrappingMul
451        + WrappingAdd
452        + WrappingSub
453        + AddAssign
454        + MulAssign,
455    N: PrimInt + Unsigned + ConstOne + ConstZero + BitAnd,
456{
457    let mut temp_r = r;
458    let mut mult = T::ONE;
459    let mut result = T::ZERO;
460
461    if n == N::ZERO {
462        return T::ZERO;
463    }
464
465    let mut n_work = n;
466    while n_work > N::ONE {
467        if n_work & N::ONE != N::ZERO {
468            result = wrapping_pow(temp_r, n_work - N::ONE)
469                .wrapping_mul(&mult)
470                .wrapping_add(&result);
471        }
472        mult = (T::ONE.wrapping_add(&temp_r)).wrapping_mul(&mult);
473        temp_r = temp_r.wrapping_mul(&temp_r);
474        n_work = n_work >> 1;
475    }
476    result = result.wrapping_add(&mult);
477    result
478}