Skip to main content

use_modular/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4//! Small modular arithmetic primitives for `RustUse`.
5
/// Validates `modulus` and widens it to `i128` for overflow-free arithmetic.
///
/// Yields `None` for zero or negative moduli, `Some(modulus)` otherwise.
fn checked_modulus(modulus: i64) -> Option<i128> {
    if modulus <= 0 {
        return None;
    }
    Some(modulus.into())
}
9
10fn normalized_i128(value: i64, modulus: i64) -> Option<i128> {
11    let modulus = checked_modulus(modulus)?;
12    Some(i128::from(value).rem_euclid(modulus))
13}
14
15/// Basic modular arithmetic helpers.
16pub mod arithmetic {
17    use crate::{checked_modulus, normalized_i128};
18
19    /// Normalizes `value` into the residue class `0..modulus`.
20    ///
21    /// Returns `None` when `modulus <= 0`.
22    #[must_use]
23    pub fn mod_normalize(value: i64, modulus: i64) -> Option<i64> {
24        i64::try_from(normalized_i128(value, modulus)?).ok()
25    }
26
27    /// Computes `(a + b) mod modulus` and returns the normalized residue.
28    ///
29    /// Returns `None` when `modulus <= 0`.
30    #[must_use]
31    pub fn mod_add(a: i64, b: i64, modulus: i64) -> Option<i64> {
32        let modulus = checked_modulus(modulus)?;
33        let sum = normalized_i128(a, modulus as i64)? + normalized_i128(b, modulus as i64)?;
34        i64::try_from(sum.rem_euclid(modulus)).ok()
35    }
36
37    /// Computes `(a - b) mod modulus` and returns the normalized residue.
38    ///
39    /// Returns `None` when `modulus <= 0`.
40    #[must_use]
41    pub fn mod_sub(a: i64, b: i64, modulus: i64) -> Option<i64> {
42        let modulus = checked_modulus(modulus)?;
43        let difference = normalized_i128(a, modulus as i64)? - normalized_i128(b, modulus as i64)?;
44        i64::try_from(difference.rem_euclid(modulus)).ok()
45    }
46
47    /// Computes `(a * b) mod modulus` and returns the normalized residue.
48    ///
49    /// Uses `i128` internally to reduce overflow risk for large `i64` inputs.
50    /// Returns `None` when `modulus <= 0`.
51    #[must_use]
52    pub fn mod_mul(a: i64, b: i64, modulus: i64) -> Option<i64> {
53        let modulus = checked_modulus(modulus)?;
54        let product = normalized_i128(a, modulus as i64)? * normalized_i128(b, modulus as i64)?;
55        i64::try_from(product.rem_euclid(modulus)).ok()
56    }
57}
58
59/// Modular exponentiation helpers.
60pub mod power {
61    use crate::{checked_modulus, normalized_i128};
62
63    /// Computes `base.pow(exponent) mod modulus` using exponentiation by squaring.
64    ///
65    /// Returns the normalized residue in `0..modulus`, or `None` when
66    /// `modulus <= 0`.
67    #[must_use]
68    pub fn mod_pow(base: i64, exponent: u64, modulus: i64) -> Option<i64> {
69        let modulus_i128 = checked_modulus(modulus)?;
70        let mut result = i128::from(1 % modulus);
71        let mut factor = normalized_i128(base, modulus)?;
72        let mut power = exponent;
73
74        while power > 0 {
75            if power & 1 == 1 {
76                result = (result * factor).rem_euclid(modulus_i128);
77            }
78
79            factor = (factor * factor).rem_euclid(modulus_i128);
80            power >>= 1;
81        }
82
83        i64::try_from(result).ok()
84    }
85}
86
87/// Modular inverse helpers.
88pub mod inverse {
89    use crate::{checked_modulus, normalized_i128};
90
91    /// Computes the multiplicative inverse of `value` modulo `modulus`.
92    ///
93    /// Returns `Some(inverse)` only when the inverse exists. The returned
94    /// residue is normalized to `0..modulus`. Returns `None` when
95    /// `modulus <= 0` or when `value` and `modulus` are not coprime.
96    #[must_use]
97    pub fn mod_inverse(value: i64, modulus: i64) -> Option<i64> {
98        let modulus_i128 = checked_modulus(modulus)?;
99        let value_i128 = normalized_i128(value, modulus)?;
100        let (gcd, coefficient, _) = extended_gcd(value_i128, modulus_i128);
101
102        (gcd == 1)
103            .then(|| coefficient.rem_euclid(modulus_i128))
104            .and_then(|inverse| i64::try_from(inverse).ok())
105    }
106
107    fn extended_gcd(a: i128, b: i128) -> (i128, i128, i128) {
108        let (mut old_r, mut r) = (a, b);
109        let (mut old_s, mut s) = (1_i128, 0_i128);
110        let (mut old_t, mut t) = (0_i128, 1_i128);
111
112        while r != 0 {
113            let quotient = old_r / r;
114
115            (old_r, r) = (r, old_r - quotient * r);
116            (old_s, s) = (s, old_s - quotient * s);
117            (old_t, t) = (t, old_t - quotient * t);
118        }
119
120        (old_r.abs(), old_s, old_t)
121    }
122}
123
124/// Modular congruence helpers.
125pub mod congruence {
126    use crate::arithmetic::mod_normalize;
127
128    /// Returns `true` when `a` and `b` are congruent modulo `modulus`.
129    ///
130    /// Returns `false` when `modulus <= 0`.
131    #[must_use]
132    pub fn is_congruent(a: i64, b: i64, modulus: i64) -> bool {
133        match (mod_normalize(a, modulus), mod_normalize(b, modulus)) {
134            (Some(left), Some(right)) => left == right,
135            _ => false,
136        }
137    }
138}
139
140pub use arithmetic::{mod_add, mod_mul, mod_normalize, mod_sub};
141pub use congruence::is_congruent;
142pub use inverse::mod_inverse;
143pub use power::mod_pow;
144
145/// A normalized modular residue paired with its positive modulus.
146#[derive(Debug, Clone, Copy, PartialEq, Eq)]
147pub struct Modular {
148    value: i64,
149    modulus: i64,
150}
151
152impl Modular {
153    /// Creates a normalized modular value.
154    ///
155    /// Returns `None` when `modulus <= 0`.
156    #[must_use]
157    pub fn new(value: i64, modulus: i64) -> Option<Self> {
158        Some(Self {
159            value: mod_normalize(value, modulus)?,
160            modulus,
161        })
162    }
163
164    /// Returns the normalized residue in `0..modulus`.
165    #[must_use]
166    pub const fn value(self) -> i64 {
167        self.value
168    }
169
170    /// Returns the positive modulus carried by this value.
171    #[must_use]
172    pub const fn modulus(self) -> i64 {
173        self.modulus
174    }
175
176    /// Adds two modular values with the same modulus.
177    ///
178    /// Returns `None` when the moduli differ.
179    #[must_use]
180    pub fn add(self, other: Self) -> Option<Self> {
181        let modulus = self.same_modulus(other)?;
182        Self::new(mod_add(self.value, other.value, modulus)?, modulus)
183    }
184
185    /// Subtracts two modular values with the same modulus.
186    ///
187    /// Returns `None` when the moduli differ.
188    #[must_use]
189    pub fn sub(self, other: Self) -> Option<Self> {
190        let modulus = self.same_modulus(other)?;
191        Self::new(mod_sub(self.value, other.value, modulus)?, modulus)
192    }
193
194    /// Multiplies two modular values with the same modulus.
195    ///
196    /// Returns `None` when the moduli differ.
197    #[must_use]
198    pub fn mul(self, other: Self) -> Option<Self> {
199        let modulus = self.same_modulus(other)?;
200        Self::new(mod_mul(self.value, other.value, modulus)?, modulus)
201    }
202
203    /// Raises the modular value to `exponent` using modular exponentiation.
204    #[must_use]
205    pub fn pow(self, exponent: u64) -> Option<Self> {
206        Self::new(mod_pow(self.value, exponent, self.modulus)?, self.modulus)
207    }
208
209    /// Computes the multiplicative inverse when one exists.
210    #[must_use]
211    pub fn inverse(self) -> Option<Self> {
212        Self::new(mod_inverse(self.value, self.modulus)?, self.modulus)
213    }
214
215    fn same_modulus(self, other: Self) -> Option<i64> {
216        (self.modulus == other.modulus).then_some(self.modulus)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::{
        Modular, is_congruent, mod_add, mod_inverse, mod_mul, mod_normalize, mod_pow, mod_sub,
    };

    #[test]
    fn accepts_positive_modulus() {
        assert_eq!(mod_normalize(0, 1), Some(0));
        assert_eq!(mod_normalize(7, 5), Some(2));
    }

    #[test]
    fn rejects_zero_modulus() {
        assert_eq!(mod_normalize(3, 0), None);
        assert_eq!(mod_add(1, 2, 0), None);
        assert_eq!(mod_sub(1, 2, 0), None);
        assert_eq!(mod_mul(1, 2, 0), None);
        assert_eq!(mod_pow(2, 3, 0), None);
        assert_eq!(mod_inverse(3, 0), None);
        assert!(!is_congruent(1, 1, 0));
    }

    #[test]
    fn rejects_negative_modulus() {
        assert_eq!(mod_normalize(3, -5), None);
        assert_eq!(mod_add(1, 2, -5), None);
        assert_eq!(mod_sub(1, 2, -5), None);
        assert_eq!(mod_mul(1, 2, -5), None);
        assert_eq!(mod_pow(2, 3, -5), None);
        assert_eq!(mod_inverse(3, -5), None);
        assert!(!is_congruent(1, 1, -5));
    }

    #[test]
    fn normalizes_positive_values() {
        assert_eq!(mod_normalize(17, 5), Some(2));
    }

    #[test]
    fn normalizes_negative_values() {
        assert_eq!(mod_normalize(-1, 5), Some(4));
        assert_eq!(mod_normalize(-13, 5), Some(2));
    }

    #[test]
    fn adds_residues() {
        assert_eq!(mod_add(4, 3, 5), Some(2));
    }

    #[test]
    fn subtracts_residues() {
        assert_eq!(mod_sub(2, 4, 5), Some(3));
    }

    #[test]
    fn multiplies_residues() {
        assert_eq!(mod_mul(4, 4, 5), Some(1));
    }

    #[test]
    fn computes_modular_powers() {
        assert_eq!(mod_pow(2, 10, 1_000), Some(24));
    }

    #[test]
    fn handles_zero_exponent() {
        assert_eq!(mod_pow(9, 0, 5), Some(1));
        assert_eq!(mod_pow(9, 0, 1), Some(0));
    }

    #[test]
    fn computes_existing_inverse() {
        assert_eq!(mod_inverse(3, 11), Some(4));
    }

    #[test]
    fn reports_missing_inverse() {
        assert_eq!(mod_inverse(2, 4), None);
    }

    #[test]
    fn checks_congruence() {
        assert!(is_congruent(17, 5, 12));
    }

    #[test]
    fn checks_non_congruence() {
        assert!(!is_congruent(17, 6, 12));
    }

    #[test]
    fn multiplies_large_values_with_i128_intermediate() {
        let left = 3_037_000_500_i64;
        let right = 3_037_000_500_i64;
        let modulus = 97_i64;
        let expected = (i128::from(left) * i128::from(right)).rem_euclid(i128::from(modulus));

        assert_eq!(mod_mul(left, right, modulus), i64::try_from(expected).ok());
    }

    #[test]
    fn constructs_and_operates_on_modular_values() {
        let left = Modular::new(-1, 5).expect("valid modular value");
        let right = Modular::new(3, 5).expect("valid modular value");
        let different = Modular::new(1, 7).expect("valid modular value");

        assert_eq!(left.value(), 4);
        assert_eq!(left.modulus(), 5);
        assert_eq!(left.add(right).map(Modular::value), Some(2));
        assert_eq!(left.sub(right).map(Modular::value), Some(1));
        assert_eq!(left.mul(right).map(Modular::value), Some(2));
        assert_eq!(right.pow(4).map(Modular::value), Some(1));
        assert_eq!(right.inverse().map(Modular::value), Some(2));
        assert_eq!(left.add(different), None);
    }

    #[test]
    fn reduces_negative_and_extreme_operands() {
        // 2^63 ≡ 1 (mod 7), so i64::MIN ≡ 6 and i64::MAX ≡ 0 modulo 7.
        assert_eq!(mod_add(-1, -1, 5), Some(3));
        assert_eq!(mod_add(i64::MIN, i64::MIN, 7), Some(5));
        assert_eq!(mod_sub(i64::MIN, i64::MAX, 7), Some(6));
        assert_eq!(mod_mul(i64::MIN, i64::MIN, 7), Some(1));
        assert_eq!(mod_pow(-2, 3, 7), Some(6));
        assert_eq!(mod_inverse(-3, 11), Some(7));
    }

    #[test]
    fn stays_exact_at_i64_max_modulus() {
        // Modulo `m`, the residue `m - 1` behaves like `-1`.
        let m = i64::MAX;
        assert_eq!(mod_normalize(i64::MIN, m), Some(m - 1));
        assert_eq!(mod_add(m - 1, m - 1, m), Some(m - 2));
        assert_eq!(mod_sub(0, 1, m), Some(m - 1));
        assert_eq!(mod_mul(m - 1, m - 1, m), Some(1));
        assert_eq!(mod_pow(m - 1, 3, m), Some(m - 1));
        assert_eq!(mod_inverse(m - 1, m), Some(m - 1));
        assert!(is_congruent(i64::MIN, m - 1, m));
    }

    #[test]
    fn modulus_one_collapses_to_zero() {
        assert_eq!(mod_normalize(i64::MIN, 1), Some(0));
        assert_eq!(mod_add(3, 4, 1), Some(0));
        assert_eq!(mod_mul(3, 4, 1), Some(0));
        assert_eq!(mod_pow(3, 4, 1), Some(0));
        assert_eq!(mod_inverse(5, 1), Some(0));
        assert!(is_congruent(3, -8, 1));
    }

    #[test]
    fn handles_modular_edge_cases() {
        let zero = Modular::new(0, 5).expect("valid modular value");
        assert_eq!(zero.inverse(), None);
        assert_eq!(zero.pow(0).map(Modular::value), Some(1));

        let extreme = Modular::new(i64::MIN, i64::MAX).expect("valid modular value");
        assert_eq!(extreme.value(), i64::MAX - 1);
        assert_eq!(extreme.mul(extreme).map(Modular::value), Some(1));

        assert_eq!(Modular::new(1, 0), None);
        assert_eq!(Modular::new(1, -3), None);
    }
}